awni's commit files

Awni Hannun
2023-11-29 10:30:41 -08:00
parent e411fcae68
commit 8ca7f9e8e9
130 changed files with 30159 additions and 0 deletions


@@ -0,0 +1,18 @@
import array
import reprlib
class FixedRepr(reprlib.Repr):
"""Only route python array instances to repr_array."""
def repr_array(self, x, maxlevel):
if isinstance(x, array.array):
return super().repr_array(x, maxlevel)
else:
return self.repr_instance(x, maxlevel)
# We need to monkey-patch reprlib so that we can use the debugger without
# renaming mlx.core.array to something else: reprlib dispatches on the type
# name, so anything called "array" would otherwise be routed to repr_array.
fixed_repr = FixedRepr()
reprlib.repr = fixed_repr.repr
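
A small sketch of the dispatch problem this patch addresses, assuming the module above has been imported so that ``fixed_repr`` is in scope: reprlib picks its handler from the value's type name, so ``mlx.core.array`` (whose type name is also "array") would be routed to ``repr_array`` and fail, while ``FixedRepr`` falls back to ``repr_instance`` for it.

import array
import mlx.core as mx

# A real array.array still goes through reprlib's repr_array handler.
print(fixed_repr.repr(array.array("f", [1.0, 2.0])))
# An mlx array also has the type name "array" but is not an array.array,
# so FixedRepr routes it to the generic repr_instance instead.
print(fixed_repr.repr(mx.array([1.0, 2.0])))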

python/mlx/extension.py Normal file

@@ -0,0 +1,94 @@
import os
import re
import subprocess
import sys
from pathlib import Path
from setuptools import Extension, setup, find_namespace_packages
from setuptools.command.build_ext import build_ext
import mlx
_MLX_PATH = str(mlx.__path__[0])
# A CMakeExtension needs a sourcedir instead of a file list.
class CMakeExtension(Extension):
def __init__(self, name: str, sourcedir: str = "") -> None:
super().__init__(name, sources=[])
self.sourcedir = os.fspath(Path(sourcedir).resolve())
class CMakeBuild(build_ext):
def build_extension(self, ext: CMakeExtension) -> None:
# Must be in this form due to a bug in .resolve() that was only fixed in Python 3.10+
ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call]
extdir = ext_fullpath.parent.resolve()
debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
cfg = "Debug" if debug else "Release"
# CMake lets you override the generator - we need to check this.
# Can be set with Conda-Build, for example.
cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
# Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
# EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
# from Python.
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
f"-DCMAKE_BUILD_TYPE={cfg}",
"-DBUILD_SHARED_LIBS=ON",
]
build_args = []
# Adding CMake arguments set as environment variable
# (needed e.g. to build for ARM OSx on conda-forge)
if "CMAKE_ARGS" in os.environ:
cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
if sys.platform.startswith("darwin"):
# Cross-compile support for macOS - respect ARCHFLAGS if set
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
if archs:
cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
# across all generators.
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
# self.parallel is a Python 3 only way to set parallel jobs by hand
# using -j in the build_ext call, not supported by pip or PyPA-build.
if hasattr(self, "parallel") and self.parallel:
# CMake 3.12+ only.
build_args += [f"-j{self.parallel}"]
build_temp = Path(self.build_temp) / ext.name
if not build_temp.exists():
build_temp.mkdir(parents=True)
# Make sure cmake can find MLX
os.environ["MLX_DIR"] = _MLX_PATH
subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
)
subprocess.run(
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
)
def run(self):
super().run()
# Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
if self.inplace:
for ext in self.extensions:
if isinstance(ext, CMakeExtension):
# Resolve inplace package dir
build_py = self.get_finalized_command("build_py")
inplace_file, regular_file = self._get_inplace_equivalent(
build_py, ext
)
inplace_dir = str(Path(inplace_file).parent.resolve())
regular_dir = str(Path(regular_file).parent.resolve())
self.copy_tree(regular_dir, inplace_dir)
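
A hypothetical ``setup.py`` sketch showing how these helpers are meant to be wired together; the ``my_mlx_ext`` package and ``my_mlx_ext._ext`` extension names are placeholders and not part of this commit.

from setuptools import setup

from mlx import extension

setup(
    name="my_mlx_ext",  # placeholder name for an out-of-tree MLX extension
    # CMakeExtension only records the directory holding the CMakeLists.txt;
    # the sources themselves are discovered and compiled by CMake.
    ext_modules=[extension.CMakeExtension("my_mlx_ext._ext", sourcedir=".")],
    # CMakeBuild configures and builds the project with cmake and exports
    # MLX_DIR so the extension's CMake project can find MLX.
    cmdclass={"build_ext": extension.CMakeBuild},
    packages=["my_mlx_ext"],
    zip_safe=False,
)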


@@ -0,0 +1,401 @@
import textwrap
from typing import Any, Callable, List, Union, Optional
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
class Module(dict):
"""Base class for building neural networks with MLX.
All the layers provided in :mod:`mlx.nn.layers` subclass this class and
your models should do the same.
A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
instances in arbitrary nesting of python lists or dicts. The ``Module``
then allows recursively extracting all the :class:`mlx.core.array` instances
using :meth:`mlx.nn.Module.parameters`.
In addition, the ``Module`` has the concept of trainable and non-trainable
parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
the gradients are returned only with respect to the trainable parameters.
All arrays in a module are trainable unless they are added in the "frozen"
set by calling :meth:`freeze`.
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
class MyMLP(nn.Module):
def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
super().__init__()
self.in_proj = nn.Linear(in_dims, hidden_dims)
self.out_proj = nn.Linear(hidden_dims, out_dims)
def __call__(self, x):
x = self.in_proj(x)
x = mx.maximum(x, 0)
return self.out_proj(x)
model = MyMLP(2, 1)
# All the model parameters are created but since MLX is lazy by
# default, they are not evaluated yet. Calling `mx.eval` actually
# allocates memory and initializes the parameters.
mx.eval(model.parameters())
# Setting a parameter to a new value is as simple as accessing that
# parameter and assigning a new array to it.
model.in_proj.weight = model.in_proj.weight * 2
mx.eval(model.parameters())
"""
def __init__(self):
"""Should be called by the subclasses of ``Module``."""
self._no_grad = set()
self._training = True
@property
def training(self):
return self._training
def _extra_repr(self):
return ""
def __repr__(self):
children = tree_flatten(self.children(), is_leaf=self.is_module)
value = f"{type(self).__name__}({self._extra_repr()}"
for k, v in children:
value += "\n"
value += textwrap.indent(f"({k}): {repr(v)}", prefix=" ")
if children:
value += "\n"
value += ")"
return value
def __getattr__(self, key: str):
if key in self:
return self[key]
else:
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
def __setattr__(self, key: str, val: Any):
self[key] = val
def load_weights(self, file: str):
"""
Load and update the model's weights from a `.npz` file.
"""
self.update(tree_unflatten(list(mx.load(file).items())))
def save_weights(self, file: str):
"""
Save the model's weights to a `.npz` file.
"""
mx.savez(file, **dict(tree_flatten(self.parameters())))
@staticmethod
def is_module(value):
return isinstance(value, Module)
@staticmethod
def valid_child_filter(module, key, value):
return isinstance(value, (dict, list))
@staticmethod
def valid_parameter_filter(module, key, value):
return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")
@staticmethod
def trainable_parameter_filter(module, key, value):
return (
Module.valid_parameter_filter(module, key, value)
and key not in module._no_grad
)
def filter_and_map(
self,
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
map_fn: Optional[Callable] = None,
is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
):
"""Recursively filter the contents of the module using ``filter_fn``,
namely only select keys and values where ``filter_fn`` returns true.
This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
but it can also be used to extract any subset of the module's parameters.
Args:
filter_fn (Callable): Given a value, the key in which it is found
and the containing module, decide whether to keep the value or
drop it.
map_fn (Callable, optional): Optionally transform the value before
returning it.
is_leaf_fn (Callable, optional): Given a value, the key in which it
is found and the containing module, decide if it is a leaf.
Returns:
A dictionary containing the contents of the module recursively filtered
"""
map_fn = map_fn or (lambda x: x)
is_leaf_fn = is_leaf_fn or (
lambda m, k, v: not isinstance(v, (Module, dict, list))
)
def unwrap(vk, v):
if is_leaf_fn(self, vk, v):
return map_fn(v)
if isinstance(v, Module):
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
if isinstance(v, dict):
nd = {}
for k, v in v.items():
tk = f"{vk}.{k}"
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
return nd
if isinstance(v, list):
nl = []
for i, vi in enumerate(v):
tk = f"{vk}.{i}"
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
def parameters(self):
"""Recursively return all the :class:`mlx.core.array` members of this Module
as a dict of dicts and lists."""
return self.filter_and_map(self.valid_parameter_filter)
def trainable_parameters(self):
"""Recursively return all the non frozen :class:`mlx.core.array` members of
this Module as a dict of dicts and lists."""
return self.filter_and_map(self.trainable_parameter_filter)
def children(self):
"""Return the direct descendants of this Module instance."""
return self.filter_and_map(
self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
)
def leaf_modules(self):
"""Return the submodules that do not contain other modules."""
def _is_leaf_module(m, k, v):
return isinstance(v, Module) and len(tree_flatten(v.children())) == 0
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict):
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.
Commonly used by the optimizer to change the model to the updated
(optimized) parameters. Also used by the :meth:`mlx.nn.value_and_grad` to set the
tracers in the model in order to compute gradients.
The passed-in parameters dictionary need not be a full dictionary
similar to :meth:`parameters`. Only the provided locations will be
updated.
Args:
parameters (dict): A complete or partial dictionary of the module's
parameters.
"""
def apply(dst, parameters):
if isinstance(parameters, dict):
for k in parameters:
if k in dst:
current_value = dst[k]
new_value = parameters[k]
if isinstance(current_value, mx.array):
dst[k] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif isinstance(parameters, list):
for i in range(len(dst)):
current_value = dst[i]
new_value = parameters[i]
if isinstance(current_value, mx.array):
dst[i] = new_value
elif isinstance(current_value, Module):
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
apply(self, parameters)
def apply(
self,
map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
):
"""Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters.
For instance, running ``model.apply(lambda x: x.astype(mx.float16))``
casts all the parameters to 16-bit floats.
Args:
map_fn (Callable): Maps an array to another array
filter_fn (Callable, optional): Filter to select which arrays to
map (default: :meth:`Module.valid_parameter_filter`).
"""
filter_fn = filter_fn or Module.valid_parameter_filter
self.update(self.filter_and_map(filter_fn, map_fn))
def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
"""Apply a function to all the modules in this instance (including this
instance).
Args:
apply_fn (Callable): The function to apply to the modules.
"""
module_stack = [("", self)]
while module_stack:
prefix, mod = module_stack.pop()
apply_fn(prefix, mod)
prefix = "." + prefix if prefix else ""
module_stack.extend(
tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
)
def modules(self):
"""Return a list with all the modules in this instance.
Returns:
A list of :class:`mlx.nn.Module` instances.
"""
modulelist = []
self.apply_to_modules(lambda k, m: modulelist.append(m))
return modulelist
def named_modules(self):
"""Return a list with all the modules in this instance and their name
with dot notation.
Returns:
A list of tuples (str, :class:`mlx.nn.Module`).
"""
modulelist = []
self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
return modulelist
def _validate_keys(self, keys, strict):
keys = keys if isinstance(keys, list) else [keys]
if strict:
for k in keys:
if k not in self:
raise KeyError(f"Module doesn't contain member {k}.")
return keys
def freeze(
self,
*,
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
):
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.
This function is idempotent, i.e., freezing a frozen model is a no-op.
For instance, to only train the attention parameters of a transformer:
model = ...
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
Args:
recurse (bool, optional): If True then freeze the parameters of the
submodules as well (default: True).
keys (str or list[str], optional): If provided, only these
parameters will be frozen, otherwise all the parameters of the
module. For instance, freeze all biases by calling
``module.freeze(keys="bias")``.
strict (bool, optional): If set to True validate that the passed keys exist
(default: False).
"""
def _freeze_impl(_, m):
local_keys = keys
if local_keys is None:
local_keys = tree_flatten(
m.filter_and_map(
lambda m, k, v: (not isinstance(v, Module))
and m.valid_parameter_filter(m, k, v)
)
)
local_keys = [k for (k, v) in local_keys]
local_keys = m._validate_keys(local_keys, strict)
m._no_grad.update(local_keys)
if recurse:
self.apply_to_modules(_freeze_impl)
else:
_freeze_impl("", self)
def unfreeze(
self,
*,
recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
):
"""Unfreeze the Module's parameters or some of them.
This function is idempotent, i.e., unfreezing a model that is not frozen is
a no-op.
For instance, to only train the biases, one can do:
model = ...
model.freeze()
model.unfreeze(keys="bias")
Args:
recurse (bool, optional): If True then unfreeze the parameters of the
submodules as well (default: True).
keys (str or list[str], optional): If provided, only these
parameters will be unfrozen, otherwise all the parameters of the
module. For instance, unfreeze all biases by calling
``module.unfreeze(keys="bias")``.
strict (bool, optional): If set to True validate that the passed keys exist
(default: False).
"""
def _unfreeze_impl(_, m):
if keys is None:
m._no_grad.clear()
else:
local_keys = m._validate_keys(keys, strict)
m._no_grad.difference_update(local_keys)
if recurse:
self.apply_to_modules(_unfreeze_impl)
else:
_unfreeze_impl("", self)
def train(self, mode: bool = True):
def _set_train(_, m):
m._training = mode
self.apply_to_modules(_set_train)
def eval(self):
self.train(False)
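
A short sketch exercising the ``Module`` API above, reusing the ``MyMLP`` class from the class docstring; it assumes ``Module`` and ``Linear`` are exported from ``mlx.nn`` as the docstrings suggest.

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

class MyMLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
        super().__init__()
        self.in_proj = nn.Linear(in_dims, hidden_dims)
        self.out_proj = nn.Linear(hidden_dims, out_dims)

    def __call__(self, x):
        return self.out_proj(mx.maximum(self.in_proj(x), 0))

model = MyMLP(2, 1)
mx.eval(model.parameters())  # materialize the lazily created parameters

# Freeze everything, then unfreeze only the biases; afterwards only the bias
# arrays are trainable and will receive gradients.
model.freeze()
model.unfreeze(keys="bias")
print([k for k, _ in tree_flatten(model.trainable_parameters())])
# ['in_proj.bias', 'out_proj.bias']

# apply() maps every parameter and updates the module in place, for example
# casting all parameters to 16-bit floats.
model.apply(lambda p: p.astype(mx.float16))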


@@ -0,0 +1,22 @@
from mlx.nn.layers.base import Module
class Sequential(Module):
"""A layer that calls the passed callables in order.
We can pass either modules or plain callables to the Sequential module. If
our functions have learnable parameters they should be implemented as
``nn.Module`` instances.
Args:
modules (tuple of Callables): The modules to call in order
"""
def __init__(self, *modules):
super().__init__()
self.layers = list(modules)
def __call__(self, x):
for m in self.layers:
x = m(x)
return x
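
A brief usage sketch, assuming ``Sequential`` and ``Linear`` are exported from ``mlx.nn``; modules and plain callables can be mixed freely.

import mlx.core as mx
import mlx.nn as nn

mlp = nn.Sequential(
    nn.Linear(8, 32),
    lambda x: mx.maximum(x, 0),  # a plain callable with no parameters
    nn.Linear(32, 1),
)

x = mx.random.normal((4, 8))
print(mlp(x).shape)  # (4, 1)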


@@ -0,0 +1,33 @@
import mlx.core as mx
from mlx.nn.layers.base import Module
class Dropout(Module):
"""Randomly zero a portion of the elements during training.
The remaining elements are multiplied by :math:`\frac{1}{1-p}` where
:math:`p` is the probability of zeroing an element. This is done so the
expected value of a given element will remain the same.
Args:
p (float): The probability to zero an element
"""
def __init__(self, p: float = 0.5):
super().__init__()
if p < 0 or p >= 1:
raise ValueError("The dropout probability should be in [0, 1)")
self._p_1 = 1 - p
def _extra_repr(self):
return f"p={1-self._p_1}"
def __call__(self, x):
if self._p_1 == 1 or not self.training:
return x
mask = mx.random.bernoulli(self._p_1, x.shape)
return (1 / self._p_1) * mask.astype(x.dtype) * x
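
A quick sketch of the training/eval behavior, assuming ``Dropout`` is exported from ``mlx.nn``: the mask is only applied while the module is in training mode, and the layer becomes the identity after calling ``eval()``.

import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(p=0.5)
x = mx.ones((8,))

# Training mode (the default): roughly half the entries are zeroed and the
# survivors are scaled by 1 / (1 - p) = 2 so the expected value is unchanged.
print(drop(x))

# Eval mode: the input is returned unchanged.
drop.eval()
print(drop(x))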


@@ -0,0 +1,178 @@
import mlx.core as mx
from mlx.nn.layers.base import Module
class LayerNorm(Module):
r"""Applies layer normalization [1] on the inputs.
Computes
.. math::
y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
parameters initialized at 1 and 0 respectively.
[1]: https://arxiv.org/abs/1607.06450
Args:
dims (int): The feature dimension of the input to normalize over
eps (float): A small additive constant for numerical stability
affine (bool): If True learn an affine transform to apply after the
normalization
"""
def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
super().__init__()
if affine:
self.bias = mx.zeros((dims,))
self.weight = mx.ones((dims,))
self.eps = eps
self.dims = dims
def _extra_repr(self):
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
def __call__(self, x):
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
x = (x - means) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x
class RMSNorm(Module):
r"""Applies Root Mean Square normalization [1] to the inputs.
Computes
.. math::
y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma
where :math:`\gamma` is a learned per feature dimension parameter initialized at
1.
[1]: https://arxiv.org/abs/1910.07467
Args:
dims (int): The feature dimension of the input to normalize over
eps (float): A small additive constant for numerical stability
"""
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _extra_repr(self):
return f"{self.weight.shape[0]}, eps={self.eps}"
def __call__(self, x):
# S is 1/sqrt(N) where N is the size of the features of x and is used
# to compute a numerically more stable RMS of x by multiplying with S
# first and summing.
#
# This way we prefer underflow over overflow which is controlled with
# the parameter epsilon anyway.
S = 1 / x.shape[-1] ** 0.5
n = (x * S).square().sum(axis=-1, keepdims=True)
n = mx.rsqrt(n + self.eps)
return self.weight * x * n
class GroupNorm(Module):
r"""Applies Group Normalization [1] to the inputs.
Computes the same normalization as layer norm, namely
.. math::
y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
parameters initialized at 1 and 0 respectively. However, the mean and
variance are computed over the spatial dimensions and each group of
features. In particular, the input is split into ``num_groups`` groups across
the feature dimension.
The feature dimension is assumed to be the last dimension and the dimensions
that precede it (except the first) are considered the spatial dimensions.
[1]: https://arxiv.org/abs/1803.08494
Args:
num_groups (int): Number of groups to separate the features into
dims (int): The feature dimensions of the input to normalize over
eps (float): A small additive constant for numerical stability
affine (bool): If True learn an affine transform to apply after the
normalization.
pytorch_compatible (bool): If True perform the group normalization in
the same order/grouping as PyTorch.
"""
def __init__(
self,
num_groups: int,
dims: int,
eps: float = 1e-5,
affine: bool = True,
pytorch_compatible: bool = False,
):
super().__init__()
if affine:
self.bias = mx.zeros((dims,))
self.weight = mx.ones((dims,))
self.num_groups = num_groups
self.dims = dims
self.eps = eps
self.pytorch_compatible = pytorch_compatible
def _extra_repr(self):
return (
f"{self.num_groups}, {self.dims}, eps={self.eps}, "
f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
)
def _pytorch_compatible_group_norm(self, x):
num_groups = self.num_groups
batch, *rest, dims = x.shape
# Split into groups
x = x.reshape(batch, -1, num_groups, dims // num_groups)
x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)
# Normalize
means = mx.mean(x, axis=1, keepdims=True)
var = mx.var(x, axis=1, keepdims=True)
x = (x - means) * mx.rsqrt(var + self.eps)
x = x.reshape(batch, -1, dims // num_groups, num_groups)
x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)
return x
def _group_norm(self, x):
num_groups = self.num_groups
batch, *rest, dims = x.shape
# Split into groups
x = x.reshape(batch, -1, num_groups)
# Normalize
means = mx.mean(x, axis=1, keepdims=True)
var = mx.var(x, axis=1, keepdims=True)
x = (x - means) * mx.rsqrt(var + self.eps)
x = x.reshape(batch, *rest, dims)
return x
def __call__(self, x):
group_norm = (
self._pytorch_compatible_group_norm
if self.pytorch_compatible
else self._group_norm
)
x = group_norm(x)
return (self.weight * x + self.bias) if "weight" in self else x
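
A small numerical sketch of the three layers above, assuming they are exported from ``mlx.nn``; with the default affine parameters (weight = 1, bias = 0) the LayerNorm output has roughly zero mean and unit variance along the feature axis.

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 5, 16))

ln = nn.LayerNorm(16)
y = ln(x)
print(mx.mean(y, axis=-1))  # ~0 for every position
print(mx.var(y, axis=-1))   # ~1 for every position

rms = nn.RMSNorm(16)
print(rms(x).shape)  # (2, 5, 16)

# 16 features split into 4 groups; mean and variance are computed over the
# spatial dimension and each group of features.
gn = nn.GroupNorm(num_groups=4, dims=16)
print(gn(x).shape)  # (2, 5, 16)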


@@ -0,0 +1,142 @@
import math
from typing import Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
class RoPE(Module):
"""Implements the rotary positional encoding [1].
The traditional implementation rotates consecutive pairs of elements in the
feature dimension, while the default implementation rotates pairs separated
by half the feature dimension, which is more efficient.
[1]: https://arxiv.org/abs/2104.09864
Args:
dims (int): The number of feature dimensions to be rotated. If the input
feature size is larger than ``dims``, the rest is left unchanged.
traditional (bool): If set to True choose the traditional
implementation which is slightly less efficient.
"""
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _extra_repr(self):
return f"{self.dims}, traditional={self.traditional}"
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = mx.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
def __call__(self, x, offset: int = 0):
shape = x.shape
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return mx.reshape(rx, shape)
@staticmethod
def create_cos_sin_theta(
N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
):
D = D // 2
positions = mx.arange(offset, N, dtype=dtype)
freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
costheta = mx.cos(theta)
sintheta = mx.sin(theta)
return costheta, sintheta
class SinusoidalPositionalEncoding(Module):
"""Implements sinusoidal positional encoding similar to [1].
[1]: https://arxiv.org/abs/1706.03762
Args:
dims (int): The dimensionality of the resulting positional embeddings.
min_freq (float): The minimum frequency expected (default: 0.0001)
max_freq (float): The maximum frequency expected (default: 1)
scale (float): Scale the embeddings by this number (default: ``sqrt(2/dims)``)
cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
instead of the other way around (default: False)
full_turns (bool): If set to True multiply the frequencies
with ``2 pi`` (default: False)
"""
def __init__(
self,
dims: int,
min_freq: float = 0.0001,
max_freq: float = 1,
scale: Optional[float] = None,
cos_first: bool = False,
full_turns: bool = False,
):
super().__init__()
one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
min_freq = math.log(min_freq)
max_freq = math.log(max_freq)
# Start with underscore so it is not included in the parameters
self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
if full_turns:
self._sigmas = self._sigmas * (2 * math.pi)
# Save some constants that define the implementation
self.scale = scale or (2 / dims) ** 0.5
self.cos_first = cos_first
def __call__(self, x):
y = x[..., None] * self._sigmas
cosy = mx.cos(y)
siny = mx.sin(y)
if self.cos_first:
y = mx.concatenate([cosy, siny], axis=-1)
else:
y = mx.concatenate([siny, cosy], axis=-1)
if self.scale != 1:
y = y * self.scale
return y
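
A shape-level sketch of the two encodings above, assuming they are exported from ``mlx.nn``. The ``offset`` argument of ``RoPE`` is what makes it convenient with a key/value cache, since only the new positions need to be rotated.

import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)
# (batch, num_heads, sequence, head_dim): RoPE only reshapes over the last
# two axes, so any number of leading dimensions is fine.
q = mx.random.normal((1, 8, 10, 64))
q_rot = rope(q)
# When decoding one token at a time, pass the current position as `offset`.
q_new = rope(mx.random.normal((1, 8, 1, 64)), offset=10)

pe = nn.SinusoidalPositionalEncoding(dims=32)
emb = pe(mx.arange(12).astype(mx.float32))  # one 32-dim embedding per position

print(q_rot.shape, q_new.shape, emb.shape)  # (1, 8, 10, 64) (1, 8, 1, 64) (12, 32)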


@@ -0,0 +1,136 @@
import math
from typing import Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
class MultiHeadAttention(Module):
"""Implements the scaled dot product attention with multiple heads.
Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
new values by aggregating information from the input values according to
the similarities of the input queries and keys.
All inputs as well as the output are linearly projected without biases.
MultiHeadAttention also expects an additive attention mask that should be
broadcastable with (batch, num_heads, # queries, # keys). The mask should
have ``-inf`` or very large negative values at the positions that should *not*
be attended to.
Args:
dims (int): The model dimensions. If no other dims are provided then
dims is used for queries, keys, values and the output.
num_heads (int): How many attention heads to use
query_input_dims (int, optional): The input dimensions of the queries (default: dims).
key_input_dims (int, optional): The input dimensions of the keys (default: dims).
value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
value_dims (int, optional): The dimensions of the values after the projection (default: dims).
value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
"""
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
f"The input feature dimensions should be divisble by the number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = Linear(query_input_dims, dims, False)
self.key_proj = Linear(key_input_dims, dims, False)
self.value_proj = Linear(value_input_dims, value_dims, False)
self.out_proj = Linear(value_dims, value_output_dims, False)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
@staticmethod
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
# Usually -inf is used, but -1e9 works just as well and keeps softmax from
# producing NaN when an entire row is masked.
# TODO: Should replace this with finfo(dtype).min
mask = mask.astype(dtype) * -1e9
return mask
class TransformerEncoderLayer(Module):
def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads)
self.ln1 = LayerNorm(dims)
self.ln2 = LayerNorm(dims)
self.linear1 = Linear(dims, mlp_dims)
self.linear2 = Linear(mlp_dims, dims)
def __call__(self, x, mask):
y = self.ln1(x)
y = self.attention(y, y, y, mask)
x = x + y
y = self.ln2(x)
y = self.linear1(y)
y = mx.maximum(y, 0)
y = self.linear2(y)
x = x + y
return x
class TransformerEncoder(Module):
def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
):
super().__init__()
self.layers = [
TransformerEncoderLayer(dims, num_heads, mlp_dims)
for i in range(num_layers)
]
self.ln = LayerNorm(dims)
def __call__(self, x, mask):
for l in self.layers:
x = l(x, mask)
x = self.ln(x)
return x
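
A short sketch running the encoder with a causal mask, assuming ``TransformerEncoder`` and ``MultiHeadAttention`` are exported from ``mlx.nn``.

import mlx.core as mx
import mlx.nn as nn

encoder = nn.TransformerEncoder(num_layers=2, dims=64, num_heads=4)

x = mx.random.normal((1, 10, 64))  # (batch, sequence, dims)
# Additive mask, broadcastable with (batch, num_heads, queries, keys), with
# large negative values at the positions that must not be attended to.
mask = nn.MultiHeadAttention.create_additive_causal_mask(10)

y = encoder(x, mask)
print(y.shape)  # (1, 10, 64)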

python/mlx/nn/utils.py Normal file

@@ -0,0 +1,31 @@
from typing import Callable
import mlx.core as mx
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
"""Transform the passed function ``fn`` to a function that computes the
gradients of ``fn`` wrt the model's trainable parameters and also its
value.
Args:
model (mlx.nn.Module): The model whose trainable parameters to compute
gradients for
fn (Callable): The scalar function to compute gradients for
Returns:
A callable that returns the value of ``fn`` and the gradients with respect
to the trainable parameters of ``model``.
"""
def inner_fn(params, *args, **kwargs):
model.update(params)
return fn(*args, **kwargs)
value_grad_fn = mx.value_and_grad(inner_fn)
def wrapped_value_grad_fn(*args, **kwargs):
value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
return value, grad
return wrapped_value_grad_fn
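
A minimal sketch of the intended use, assuming ``value_and_grad`` is exported as ``mlx.nn.value_and_grad`` (as the docstrings in this commit reference it): the loss closes over the model, and the returned gradients mirror the structure of ``model.trainable_parameters()``.

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 1)

def loss_fn(x, y):
    return mx.mean(mx.square(model(x) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((16, 4))
y = mx.random.normal((16, 1))
loss, grads = loss_and_grad_fn(x, y)
# grads has the same tree structure as model.trainable_parameters(),
# e.g. {"weight": ..., "bias": ...} for a Linear model.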

python/mlx/optimizers.py Normal file

@@ -0,0 +1,152 @@
import math
from typing import List
import mlx.core as mx
from mlx.utils import tree_map
class OptimizerState(dict):
"""The optimizer state implements a recursively defined
:class:`collections.defaultdict`, namely a missing key in an optimizer
state is an :class:`OptimizerState`.
.. note::
:meth:`OptimizerState.get`, in contrast to a normal dictionary, also sets
the key to the ``default`` value if the ``key`` was not present in the
dictionary.
"""
def __getitem__(self, key):
if key not in self:
self[key] = OptimizerState()
return super().__getitem__(key)
def get(self, key, default):
"""If ``key`` doesn't exist set its value to ``default`` and then return it."""
if key not in self:
self[key] = default
return super().__getitem__(key)
class Optimizer:
"""The base class for all optimizers. It allows us to implement an
optimizer on a per-parameter basis and apply it to a parameter tree.
Attributes:
state (OptimizerState): It holds the optimizer's state dictionary.
"""
def __init__(self):
self.state = OptimizerState()
def update(self, model: "mlx.nn.Module", gradients: dict):
"""Apply the gradients to the parameters of the model and update the
model with the new parameters.
Args:
model (mlx.nn.Module): An mlx module to be updated.
gradients (dict): A Python tree of gradients, most likely computed
via :func:`mlx.nn.value_and_grad`.
"""
model.update(self.apply_gradients(gradients, model))
def apply_gradients(self, gradients: dict, model: dict):
"""Apply the gradients to the parameters and return the updated parameters.
Can be used to update a model via
``model.update(opt.apply_gradients(grads, model))`` which is precisely
how :meth:`Optimizer.update` is implemented.
Args:
gradients (dict): A Python tree of gradients.
model (dict): A Python tree of parameters. It can be a superset of
the gradients. In that case the returned python tree
will be of the same structure as the gradients.
"""
return tree_map(self.apply_single, gradients, model, self.state)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""To be extended by the children classes to implement each optimizer's
update."""
raise NotImplementedError()
class SGD(Optimizer):
r"""Stochastic gradient descent optimizer.
Updates a parameter :math:`w` with a gradient :math:`g` as follows
.. math::
v_{t+1} &= \mu v_t + (1 - \mu) g_t \\
w_{t+1} &= w_t - \lambda v_{t+1}
Args:
learning_rate (float): The learning rate :math:`\lambda` for the update
momentum (float): The momentum strength :math:`\mu`
"""
def __init__(self, learning_rate: float, momentum: float = 0.0):
super().__init__()
self.learning_rate = learning_rate
self.momentum = momentum
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""Performs the SGD parameter update and stores :math:`v` in the
optimizer state."""
if self.momentum <= 0:
return parameter - self.learning_rate * gradient
v = state.get("v", mx.zeros_like(gradient))
v = self.momentum * v + (1 - self.momentum) * gradient
state["v"] = v
return parameter - self.learning_rate * v
class Adam(Optimizer):
r"""Implementation of the Adam optimizer [1].
Our Adam implementation follows the original paper, except that it omits the
bias correction in the first and second moment estimates. In detail,
.. math::
m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
[1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
optimization. ICLR 2015.
"""
def __init__(
self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
):
super().__init__()
self.learning_rate = learning_rate
self.betas = betas
self.eps = eps
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""Performs the Adam parameter update and stores :math:`v` and
:math:`m` in the optimizer state."""
lr = self.learning_rate
b1, b2 = self.betas
eps = self.eps
m = state.get("m", gradient)
v = state.get("v", mx.square(gradient))
m = b1 * m + (1 - b1) * gradient
v = b2 * v + (1 - b2) * mx.square(gradient)
state["m"] = m
state["v"] = v
return parameter - lr * m / (mx.sqrt(v) + eps)
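
A minimal end-to-end training-step sketch tying the pieces of this commit together, assuming ``mlx.optimizers`` and ``mlx.nn`` export the classes above: gradients from ``mlx.nn.value_and_grad`` are handed to the optimizer, which writes the updated parameters back into the model via ``Module.update``.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)
optimizer = optim.SGD(learning_rate=0.1, momentum=0.9)

def loss_fn(x, y):
    return mx.mean(mx.square(model(x) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((256, 4))
y = mx.random.normal((256, 1))

for step in range(10):
    loss, grads = loss_and_grad_fn(x, y)
    # Applies the SGD rule per parameter (keeping the momentum buffers in
    # optimizer.state) and calls model.update with the new parameters.
    optimizer.update(model, grads)
    # Force evaluation so the lazy graph does not keep growing across steps.
    mx.eval(model.parameters())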