Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-18 01:50:16 +08:00)

awni's commit files
python/mlx/_reprlib_fix.py (new file, 18 lines)
@@ -0,0 +1,18 @@
import array
import reprlib


class FixedRepr(reprlib.Repr):
    """Only route python array instances to repr_array."""

    def repr_array(self, x, maxlevel):
        if isinstance(x, array.array):
            return super().repr_array(x, maxlevel)
        else:
            return self.repr_instance(x, maxlevel)


# We need to monkey-patch reprlib so that we can use the debugger without
# renaming the array to something else
fixed_repr = FixedRepr()
reprlib.repr = fixed_repr.repr
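Usage note (not part of the commit): ``reprlib.Repr`` dispatches on the type name, and ``mlx.core.array`` shares the name ``array`` with the builtin ``array.array``, which appears to be why the patch routes everything else to ``repr_instance``. A minimal sketch of the patched behavior, assuming the module is simply imported for its side effect:

    import array
    import reprlib

    import mlx.core as mx
    import mlx._reprlib_fix  # importing installs FixedRepr as reprlib.repr

    # A builtin array.array is still truncated by reprlib as before.
    print(reprlib.repr(array.array("f", range(1000))))

    # An mlx array (whose type name is also "array") now falls back to
    # repr_instance instead of going through Repr.repr_array.
    print(reprlib.repr(mx.zeros((4,))))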
python/mlx/extension.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import os
import re
import subprocess
import sys
from pathlib import Path

from setuptools import Extension, setup, find_namespace_packages
from setuptools.command.build_ext import build_ext

import mlx

_MLX_PATH = str(mlx.__path__[0])


# A CMakeExtension needs a sourcedir instead of a file list.
class CMakeExtension(Extension):
    def __init__(self, name: str, sourcedir: str = "") -> None:
        super().__init__(name, sources=[])
        self.sourcedir = os.fspath(Path(sourcedir).resolve())


class CMakeBuild(build_ext):
    def build_extension(self, ext: CMakeExtension) -> None:
        # Must be in this form due to bug in .resolve() only fixed in Python 3.10+
        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)  # type: ignore[no-untyped-call]
        extdir = ext_fullpath.parent.resolve()

        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
        cfg = "Debug" if debug else "Release"

        # CMake lets you override the generator - we need to check this.
        # Can be set with Conda-Build, for example.
        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
        # from Python.
        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
            f"-DCMAKE_BUILD_TYPE={cfg}",
            "-DBUILD_SHARED_LIBS=ON",
        ]
        build_args = []
        # Adding CMake arguments set as environment variable
        # (needed e.g. to build for ARM macOS on conda-forge)
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

        if sys.platform.startswith("darwin"):
            # Cross-compile support for macOS - respect ARCHFLAGS if set
            archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
            if archs:
                cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]

        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
        # across all generators.
        if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
            # self.parallel is a Python 3 only way to set parallel jobs by hand
            # using -j in the build_ext call, not supported by pip or PyPA-build.
            if hasattr(self, "parallel") and self.parallel:
                # CMake 3.12+ only.
                build_args += [f"-j{self.parallel}"]

        build_temp = Path(self.build_temp) / ext.name
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        # Make sure cmake can find MLX
        os.environ["MLX_DIR"] = _MLX_PATH

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
        )
        subprocess.run(
            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
        )

    def run(self):
        super().run()

        # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
        if self.inplace:
            for ext in self.extensions:
                if isinstance(ext, CMakeExtension):
                    # Resolve inplace package dir
                    build_py = self.get_finalized_command("build_py")
                    inplace_file, regular_file = self._get_inplace_equivalent(
                        build_py, ext
                    )

                    inplace_dir = str(Path(inplace_file).parent.resolve())
                    regular_dir = str(Path(regular_file).parent.resolve())

                    self.copy_tree(regular_dir, inplace_dir)
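Usage note (not part of the commit): these helpers are meant to be consumed by a downstream C++ extension's ``setup.py``. A hedged sketch is below; the package and extension names are placeholders, not names defined anywhere in this diff:

    # Hypothetical setup.py for an out-of-tree MLX C++ extension.
    from setuptools import setup

    from mlx import extension

    if __name__ == "__main__":
        setup(
            name="my_mlx_ext",                 # placeholder package name
            version="0.0.1",
            packages=["my_mlx_ext"],
            ext_modules=[extension.CMakeExtension("my_mlx_ext._ext")],
            cmdclass={"build_ext": extension.CMakeBuild},
            zip_safe=False,
        )

CMakeBuild sets ``MLX_DIR`` before invoking cmake, so the extension's CMakeLists can locate the installed MLX package.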
python/mlx/nn/layers/base.py (new file, 401 lines)
@@ -0,0 +1,401 @@
import textwrap
from typing import Any, Callable, List, Union, Optional

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`mlx.nn.Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non-trainable
    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added to the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """

    def __init__(self):
        """Should be called by the subclasses of ``Module``."""
        self._no_grad = set()
        self._training = True

    @property
    def training(self):
        return self._training

    def _extra_repr(self):
        return ""

    def __repr__(self):
        children = tree_flatten(self.children(), is_leaf=self.is_module)
        value = f"{type(self).__name__}({self._extra_repr()}"
        for k, v in children:
            value += "\n"
            value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
        if children:
            value += "\n"
        value += ")"

        return value

    def __getattr__(self, key: str):
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")

    def __setattr__(self, key: str, val: Any):
        self[key] = val

    def load_weights(self, file: str):
        """
        Load and update the model's weights from a `.npz` file.
        """
        self.update(tree_unflatten(list(mx.load(file).items())))

    def save_weights(self, file: str):
        """
        Save the model's weights to a `.npz` file.
        """
        mx.savez(file, **dict(tree_flatten(self.parameters())))

    @staticmethod
    def is_module(value):
        return isinstance(value, Module)

    @staticmethod
    def valid_child_filter(module, key, value):
        return isinstance(value, (dict, list))

    @staticmethod
    def valid_parameter_filter(module, key, value):
        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")

    @staticmethod
    def trainable_parameter_filter(module, key, value):
        return (
            Module.valid_parameter_filter(module, key, value)
            and key not in module._no_grad
        )

    def filter_and_map(
        self,
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
        map_fn: Optional[Callable] = None,
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns true.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Args:
            filter_fn (Callable): Given a value, the key in which it is found
                and the containing module, decide whether to keep the value or
                drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given a value, the key in which it
                is found and the containing module, decide if it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively filtered.
        """

        map_fn = map_fn or (lambda x: x)
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

        def unwrap(vk, v):
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, v in v.items():
                    tk = f"{vk}.{k}"
                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}

    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this Module
        as a dict of dicts and lists."""
        return self.filter_and_map(self.valid_parameter_filter)

    def trainable_parameters(self):
        """Recursively return all the non-frozen :class:`mlx.core.array` members of
        this Module as a dict of dicts and lists."""
        return self.filter_and_map(self.trainable_parameter_filter)

    def children(self):
        """Return the direct descendants of this Module instance."""
        return self.filter_and_map(
            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
        )

    def leaf_modules(self):
        """Return the submodules that do not contain other modules."""

        def _is_leaf_module(m, k, v):
            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0

        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

    def update(self, parameters: dict):
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by :meth:`mlx.nn.value_and_grad` to set the
        tracers in the model in order to compute gradients.

        The passed in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            parameters (dict): A complete or partial dictionary of the module's
                parameters.
        """

        def apply(dst, parameters):
            if isinstance(parameters, dict):
                for k in parameters:
                    if k in dst:
                        current_value = dst[k]
                        new_value = parameters[k]
                        if isinstance(current_value, mx.array):
                            dst[k] = new_value
                        elif isinstance(current_value, Module):
                            current_value.update(new_value)
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(parameters, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = parameters[i]
                    if isinstance(current_value, mx.array):
                        dst[i] = new_value
                    elif isinstance(current_value, Module):
                        current_value.update(new_value)
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, parameters)

    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16 bit floats.

        Args:
            map_fn (Callable): Maps an array to another array
            filter_fn (Callable, optional): Filter to select which arrays to
                map (default: :meth:`Module.valid_parameter_filter`).
        """
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.
        """
        module_stack = [("", self)]
        while module_stack:
            prefix, mod = module_stack.pop()
            apply_fn(prefix, mod)
            prefix = "." + prefix if prefix else ""
            module_stack.extend(
                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
            )

    def modules(self):
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`mlx.nn.Module` instances.
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append(m))
        return modulelist

    def named_modules(self):
        """Return a list with all the modules in this instance and their name
        with dot notation.

        Returns:
            A list of tuples (str, :class:`mlx.nn.Module`).
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
        return modulelist

    def _validate_keys(self, keys, strict):
        keys = keys if isinstance(keys, list) else [keys]
        if strict:
            for k in keys:
                if k not in self:
                    raise KeyError(f"Module doesn't contain member {k}.")
        return keys

    def freeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Freeze the Module's parameters or some of them. Freezing a parameter
        means not computing gradients for it.

        This function is idempotent, i.e. freezing a frozen model is a no-op.

        For instance to only train the attention parameters from a transformer:

            model = ...
            model.freeze()
            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If True then freeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be frozen otherwise all the parameters of a
                module. For instance freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _freeze_impl(_, m):
            local_keys = keys
            if local_keys is None:
                local_keys = tree_flatten(
                    m.filter_and_map(
                        lambda m, k, v: (not isinstance(v, Module))
                        and m.valid_parameter_filter(m, k, v)
                    )
                )
                local_keys = [k for (k, v) in local_keys]

            local_keys = m._validate_keys(local_keys, strict)
            m._no_grad.update(local_keys)

        if recurse:
            self.apply_to_modules(_freeze_impl)
        else:
            _freeze_impl("", self)

    def unfreeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent, i.e. unfreezing a model that is not frozen
        is a no-op.

        For instance to only train the biases one can do:

            model = ...
            model.freeze()
            model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If True then unfreeze the parameters of the
                submodules as well (default: True).
            keys (str or list[str], optional): If provided then only these
                parameters will be unfrozen otherwise all the parameters of a
                module. For instance unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to True validate that the passed keys exist
                (default: False).
        """

        def _unfreeze_impl(_, m):
            if keys is None:
                m._no_grad.clear()
            else:
                local_keys = m._validate_keys(keys, strict)
                m._no_grad.difference_update(local_keys)

        if recurse:
            self.apply_to_modules(_unfreeze_impl)
        else:
            _unfreeze_impl("", self)

    def train(self, mode: bool = True):
        def _set_train(_, m):
            m._training = mode

        self.apply_to_modules(_set_train)

    def eval(self):
        self.train(False)
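Usage note (not part of the commit): a short sketch of how the parameter and freezing machinery above composes, assuming the layers are re-exported under ``mlx.nn`` as in the class docstring:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 1))
    mx.eval(model.parameters())

    # Freeze everything, then unfreeze only the biases. trainable_parameters()
    # now returns just the "bias" arrays, which is what value_and_grad sees.
    model.freeze()
    model.unfreeze(keys="bias")
    print(model.trainable_parameters())

    # Cast every parameter to float16 via apply/filter_and_map.
    model.apply(lambda x: x.astype(mx.float16))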
python/mlx/nn/layers/containers.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from mlx.nn.layers.base import Module


class Sequential(Module):
    """A layer that calls the passed callables in order.

    We can pass either modules or plain callables to the Sequential module. If
    our functions have learnable parameters they should be implemented as
    ``nn.Module`` instances.

    Args:
        modules (tuple of Callables): The modules to call in order
    """

    def __init__(self, *modules):
        super().__init__()
        self.layers = list(modules)

    def __call__(self, x):
        for m in self.layers:
            x = m(x)
        return x
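Usage note (not part of the commit): a minimal sketch, assuming ``Sequential`` and ``Linear`` are re-exported under ``mlx.nn``:

    import mlx.core as mx
    import mlx.nn as nn

    # Mix Module instances and plain callables; only the Linear layers
    # contribute learnable parameters.
    mlp = nn.Sequential(
        nn.Linear(4, 16),
        lambda x: mx.maximum(x, 0),  # ReLU as a plain callable
        nn.Linear(16, 1),
    )
    y = mlp(mx.random.normal((8, 4)))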
python/mlx/nn/layers/dropout.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class Dropout(Module):
    r"""Randomly zero a portion of the elements during training.

    The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
    :math:`p` is the probability of zeroing an element. This is done so the
    expected value of a given element will remain the same.

    Args:
        p (float): The probability to zero an element
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError("The dropout probability should be in [0, 1)")

        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        if self._p_1 == 1 or not self.training:
            return x

        mask = mx.random.bernoulli(self._p_1, x.shape)

        return (1 / self._p_1) * mask.astype(x.dtype) * x
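Usage note (not part of the commit): dropout is only active while ``Module.training`` is True; ``eval()`` from the base class turns it into the identity. A minimal sketch, assuming the ``mlx.nn`` re-export:

    import mlx.core as mx
    import mlx.nn as nn

    drop = nn.Dropout(p=0.2)
    x = mx.ones((2, 4))

    y_train = drop(x)   # roughly 20% zeros, survivors scaled by 1/0.8
    drop.eval()         # sets training=False on the module tree
    y_eval = drop(x)    # identity: returns x unchanged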
python/mlx/nn/layers/normalization.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import mlx.core as mx
from mlx.nn.layers.base import Module


class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    [1]: https://arxiv.org/abs/1607.06450

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x


class RMSNorm(Module):
    r"""Applies Root Mean Square normalization [1] to the inputs.

    Computes

    .. math::

        y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma

    where :math:`\gamma` is a learned per feature dimension parameter initialized at
    1.

    [1]: https://arxiv.org/abs/1910.07467

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.eps}"

    def __call__(self, x):
        # S is 1/sqrt(N) where N is the size of the features of x and is used
        # to compute a numerically more stable RMS of x by multiplying with S
        # first and summing.
        #
        # This way we prefer underflow over overflow which is controlled with
        # the parameter epsilon anyway.
        S = 1 / x.shape[-1] ** 0.5

        n = (x * S).square().sum(axis=-1, keepdims=True)
        n = mx.rsqrt(n + self.eps)

        return self.weight * x * n


class GroupNorm(Module):
    r"""Applies Group Normalization [1] to the inputs.

    Computes the same normalization as layer norm, namely

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. However, the mean and
    variance are computed over the spatial dimensions and each group of
    features. In particular, the input is split into num_groups across the
    feature dimension.

    The feature dimension is assumed to be the last dimension and the dimensions
    that precede it (except the first) are considered the spatial dimensions.

    [1]: https://arxiv.org/abs/1803.08494

    Args:
        num_groups (int): Number of groups to separate the features into
        dims (int): The feature dimensions of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization.
        pytorch_compatible (bool): If True perform the group normalization in
            the same order/grouping as PyTorch.
    """

    def __init__(
        self,
        num_groups: int,
        dims: int,
        eps: float = 1e-5,
        affine: bool = True,
        pytorch_compatible: bool = False,
    ):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.num_groups = num_groups
        self.dims = dims
        self.eps = eps
        self.pytorch_compatible = pytorch_compatible

    def _extra_repr(self):
        return (
            f"{self.num_groups}, {self.dims}, eps={self.eps}, "
            f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
        )

    def _pytorch_compatible_group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups, dims // num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, -1, dims // num_groups, num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)

        return x

    def _group_norm(self, x):
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, *rest, dims)

        return x

    def __call__(self, x):
        group_norm = (
            self._pytorch_compatible_group_norm
            if self.pytorch_compatible
            else self._group_norm
        )
        x = group_norm(x)
        return (self.weight * x + self.bias) if "weight" in self else x
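Usage note (not part of the commit): all three layers normalize over the trailing feature axis and preserve the input shape. A minimal sketch with an illustrative (batch, spatial, features) input, assuming the ``mlx.nn`` re-exports:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((2, 5, 16))            # (batch, spatial, features)

    ln = nn.LayerNorm(dims=16)                  # stats over the last axis
    rms = nn.RMSNorm(dims=16)                   # scale-only normalization
    gn = nn.GroupNorm(num_groups=4, dims=16)    # stats over spatial dims per group

    print(ln(x).shape, rms(x).shape, gn(x).shape)  # all (2, 5, 16)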
python/mlx/nn/layers/positional_encoding.py (new file, 142 lines)
@@ -0,0 +1,142 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module


class RoPE(Module):
    """Implements the rotary positional encoding [1].

    The traditional implementation rotates consecutive pairs of elements in the
    feature dimension while the default implementation rotates pairs with
    stride half the feature dimensions for efficiency.

    [1]: https://arxiv.org/abs/2104.09864

    Args:
        dims (int): The feature dimensions to be rotated. If the input feature
            is larger than dims then the rest is left unchanged.
        traditional (bool): If set to True choose the traditional
            implementation which is slightly less efficient.
    """

    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _extra_repr(self):
        return f"{self.dims}, traditional={self.traditional}"

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = mx.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return mx.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
    ):
        D = D // 2
        positions = mx.arange(offset, N, dtype=dtype)
        freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)

        return costheta, sintheta


class SinusoidalPositionalEncoding(Module):
    """Implements sinusoidal positional encoding similar to [1].

    [1]: https://arxiv.org/abs/1706.03762

    Args:
        dims (int): The dimensionality of the resulting positional embeddings.
        min_freq (float): The minimum frequency expected (default: 0.0001)
        max_freq (float): The maximum frequency expected (default: 1)
        scale (float): Scale the embeddings by that number (default: sqrt(2/dims))
        cos_first (bool): If set to True embed using ``[cos(x); sin(x)]``
            instead of the other way around (default: False)
        full_turns (bool): If set to True multiply the frequencies
            with ``2 pi`` (default: False)
    """

    def __init__(
        self,
        dims: int,
        min_freq: float = 0.0001,
        max_freq: float = 1,
        scale: Optional[float] = None,
        cos_first: bool = False,
        full_turns: bool = False,
    ):
        super().__init__()

        one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
        min_freq = math.log(min_freq)
        max_freq = math.log(max_freq)

        # Start with underscore so it is not included in the parameters
        self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
        if full_turns:
            self._sigmas = self._sigmas * (2 * math.pi)

        # Save some constants that define the implementation
        self.scale = scale or (2 / dims) ** 0.5
        self.cos_first = cos_first

    def __call__(self, x):
        y = x[..., None] * self._sigmas
        cosy = mx.cos(y)
        siny = mx.sin(y)

        if self.cos_first:
            y = mx.concatenate([cosy, siny], axis=-1)
        else:
            y = mx.concatenate([siny, cosy], axis=-1)

        if self.scale != 1:
            y = y * self.scale

        return y
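Usage note (not part of the commit): RoPE operates on the last two axes (sequence, features) and ``offset`` continues the rotation for cached/incremental decoding; the sinusoidal encoding maps scalar positions to ``dims``-dimensional vectors. A minimal sketch with illustrative shapes, assuming the ``mlx.nn`` re-exports:

    import mlx.core as mx
    import mlx.nn as nn

    rope = nn.RoPE(dims=64)
    q = mx.random.normal((1, 8, 10, 64))        # (batch, heads, seq, head_dim)
    q_rot = rope(q)                             # same shape as q
    q_next = rope(mx.random.normal((1, 8, 1, 64)), offset=10)  # continue at position 10

    pe = nn.SinusoidalPositionalEncoding(dims=32)
    emb = pe(mx.arange(0, 10, dtype=mx.float32))  # shape (10, 32)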
python/mlx/nn/layers/transformer.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm


class MultiHeadAttention(Module):
    """Implements the scaled dot product attention with multiple heads.

    Given inputs for queries, keys and values the ``MultiHeadAttention`` produces
    new values by aggregating information from the input values according to
    the similarities of the input queries and keys.

    All inputs as well as the output are linearly projected without biases.

    MultiHeadAttention also expects an additive attention mask that should be
    broadcastable with (batch, num_heads, # queries, # keys). The mask should
    have ``-inf`` or very large negative values at the positions that should
    *not* be attended to.

    Args:
        dims (int): The model dimensions. If no other dims are provided then
            dims is used for queries, keys, values and the output.
        num_heads (int): How many attention heads to use
        query_input_dims (int, optional): The input dimensions of the queries (default: dims).
        key_input_dims (int, optional): The input dimensions of the keys (default: dims).
        value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims).
        value_dims (int, optional): The dimensions of the values after the projection (default: dims).
        value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims).
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
    ):
        super().__init__()

        if (dims % num_heads) != 0:
            raise ValueError(
                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
            )

        query_input_dims = query_input_dims or dims
        key_input_dims = key_input_dims or dims
        value_input_dims = value_input_dims or key_input_dims
        value_dims = value_dims or dims
        value_output_dims = value_output_dims or dims

        self.num_heads = num_heads
        self.query_proj = Linear(query_input_dims, dims, False)
        self.key_proj = Linear(key_input_dims, dims, False)
        self.value_proj = Linear(value_input_dims, value_dims, False)
        self.out_proj = Linear(value_dims, value_output_dims, False)

    def __call__(self, queries, keys, values, mask=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat)

    @staticmethod
    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
        # TODO: Should replace this with finfo(dtype).min
        mask = mask.astype(dtype) * -1e9
        return mask


class TransformerEncoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)

    def __call__(self, x, mask):
        y = self.ln1(x)
        y = self.attention(y, y, y, mask)
        x = x + y

        y = self.ln2(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y

        return x


class TransformerEncoder(Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for i in range(num_layers)
        ]
        self.ln = LayerNorm(dims)

    def __call__(self, x, mask):
        for l in self.layers:
            x = l(x, mask)
        x = self.ln(x)

        return x
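Usage note (not part of the commit): the encoder expects inputs of shape (batch, sequence, dims) plus an additive mask broadcastable with (batch, num_heads, queries, keys). A minimal sketch, assuming the classes are re-exported under ``mlx.nn``:

    import mlx.core as mx
    import mlx.nn as nn

    encoder = nn.TransformerEncoder(num_layers=2, dims=64, num_heads=4)
    x = mx.random.normal((1, 10, 64))                       # (batch, seq, dims)
    mask = nn.MultiHeadAttention.create_additive_causal_mask(10)
    y = encoder(x, mask)                                    # shape (1, 10, 64)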
python/mlx/nn/utils.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from typing import Callable

import mlx.core as mx


def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``
    """

    def inner_fn(params, *args, **kwargs):
        model.update(params)
        return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)

    def wrapped_value_grad_fn(*args, **kwargs):
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn
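Usage note (not part of the commit): the wrapper closes over the model, feeds its trainable parameters to ``mx.value_and_grad``, and writes the traced parameters back via ``Module.update`` before calling ``fn``. A minimal sketch, assuming ``value_and_grad`` and ``Linear`` are re-exported under ``mlx.nn``:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 1)

    def loss_fn(x, y):
        return mx.mean(mx.square(model(x) - y))

    loss_and_grad = nn.value_and_grad(model, loss_fn)
    x, y = mx.random.normal((8, 4)), mx.random.normal((8, 1))
    loss, grads = loss_and_grad(x, y)   # grads mirrors model.trainable_parameters()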
python/mlx/optimizers.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import math
from typing import List

import mlx.core as mx
from mlx.utils import tree_map


class OptimizerState(dict):
    """The optimizer state implements a recursively defined
    :class:`collections.defaultdict`, namely a missing key in an optimizer
    state is an :class:`OptimizerState`.

    .. note::
       :meth:`OptimizerState.get` in contrast to a normal dictionary also sets
       the key to the ``default`` value if the ``key`` was not present in the
       dictionary.
    """

    def __getitem__(self, key):
        if key not in self:
            self[key] = OptimizerState()
        return super().__getitem__(key)

    def get(self, key, default):
        """If ``key`` doesn't exist set its value to ``default`` and then return it."""
        if key not in self:
            self[key] = default
        return super().__getitem__(key)


class Optimizer:
    """The base class for all optimizers. It allows us to implement an
    optimizer on a per-parameter basis and apply it to a parameter tree.

    Attributes:
        state (OptimizerState): It holds the optimizer's state dictionary.
    """

    def __init__(self):
        self.state = OptimizerState()

    def update(self, model: "mlx.nn.Module", gradients: dict):
        """Apply the gradients to the parameters of the model and update the
        model with the new parameters.

        Args:
            model (mlx.nn.Module): An mlx module to be updated.
            gradients (dict): A Python tree of gradients, most likely computed
                via :func:`mlx.nn.value_and_grad`.
        """
        model.update(self.apply_gradients(gradients, model))

    def apply_gradients(self, gradients: dict, model: dict):
        """Apply the gradients to the parameters and return the updated parameters.

        Can be used to update a model via
        ``model.update(opt.apply_gradients(grads, model))`` which is precisely
        how :meth:`Optimizer.update` is implemented.

        Args:
            gradients (dict): A Python tree of gradients.
            model (dict): A Python tree of parameters. It can be a superset of
                the gradients. In that case the returned python tree
                will be of the same structure as the gradients.
        """
        return tree_map(self.apply_single, gradients, model, self.state)

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """To be extended by the children classes to implement each optimizer's
        update."""
        raise NotImplementedError()


class SGD(Optimizer):
    r"""Stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

    .. math::

        v_{t+1} &= \mu v_t + (1 - \mu) g_t \\
        w_{t+1} &= w_t - \lambda v_{t+1}

    Args:
        learning_rate (float): The learning rate :math:`\lambda` for the update
        momentum (float): The momentum strength :math:`\mu`
    """

    def __init__(self, learning_rate: float, momentum: float = 0.0):
        super().__init__()

        self.learning_rate = learning_rate
        self.momentum = momentum

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the SGD parameter update and stores :math:`v` in the
        optimizer state."""
        if self.momentum <= 0:
            return parameter - self.learning_rate * gradient

        v = state.get("v", mx.zeros_like(gradient))
        v = self.momentum * v + (1 - self.momentum) * gradient
        state["v"] = v
        return parameter - self.learning_rate * v


class Adam(Optimizer):
    r"""Implementation of the Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.
    """

    def __init__(
        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.eps = eps

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adam parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        eps = self.eps

        m = state.get("m", gradient)
        v = state.get("v", mx.square(gradient))
        m = b1 * m + (1 - b1) * gradient
        v = b2 * v + (1 - b2) * mx.square(gradient)
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (mx.sqrt(v) + eps)
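Usage note (not part of the commit): a minimal training-loop sketch tying the optimizers to the ``nn`` pieces above, assuming the usual ``mlx.nn`` re-exports:

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(4, 1)
    optimizer = optim.SGD(learning_rate=0.1, momentum=0.9)

    def loss_fn(x, y):
        return mx.mean(mx.square(model(x) - y))

    loss_and_grad = nn.value_and_grad(model, loss_fn)

    x, y = mx.random.normal((8, 4)), mx.random.normal((8, 1))
    for _ in range(10):
        loss, grads = loss_and_grad(x, y)
        optimizer.update(model, grads)      # writes new parameters into the model
        mx.eval(model.parameters(), optimizer.state)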