angelos's commit files
python/mlx/nn/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from mlx.nn.layers import *
from mlx.nn import losses
from mlx.nn.utils import value_and_grad
python/mlx/nn/layers/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from mlx.nn.layers.base import Module
from mlx.nn.layers.activations import (
    GELU,
    ReLU,
    SiLU,
    gelu,
    gelu_approx,
    gelu_fast_approx,
    relu,
    silu,
)
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    TransformerEncoder,
    TransformerEncoderLayer,
)
python/mlx/nn/layers/activations.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import math

import mlx.core as mx
from mlx.nn.layers.base import Module


def _make_activation_module(f):
    def decorator(klass):
        klass.__doc__ = f.__doc__
        klass.__call__ = lambda self, x: f(x)
        return klass

    return decorator


def relu(x):
    """Applies the Rectified Linear Unit.

    Simply ``mx.maximum(x, 0)``.
    """
    return mx.maximum(x, 0)


def silu(x):
    r"""Applies the Sigmoid Linear Unit.

    Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
    the logistic sigmoid.
    """
    return x * mx.sigmoid(x)


def gelu(x):
    r"""Applies the Gaussian Error Linear Units function.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster
    approximations.
    """
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2


def gelu_approx(x):
    r"""An approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    This function approximates ``gelu`` with a maximum absolute error :math:`<
    0.0003` in the range :math:`[-6, 6]` using the following

    .. math::

        x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.
    """
    return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))


def gelu_fast_approx(x):
    r"""A fast approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    This function approximates ``gelu`` with a maximum absolute error :math:`<
    0.015` in the range :math:`[-6, 6]` using the following

    .. math::

        x = x \sigma\left(1.773 x\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.
    """
    return x * mx.sigmoid(1.773 * x)


@_make_activation_module(relu)
class ReLU(Module):
    pass


@_make_activation_module(silu)
class SiLU(Module):
    pass


class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    However, if ``approx`` is set to 'precise' or 'fast' it applies

    .. math::
        \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
        \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)

    respectively.

    See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
    functional equivalents and information regarding error bounds.

    Args:
        approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
    """

    def __init__(self, approx="none"):
        super().__init__()

        if approx == "none":
            self._act = gelu
        elif approx == "precise":
            self._act = gelu_approx
        elif approx == "fast":
            self._act = gelu_fast_approx
        else:
            raise ValueError(
                f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
            )

    def __call__(self, x):
        return self._act(x)
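
Example usage of the activations above (an illustrative sketch, not part of the commit; assumes the mlx Python package is built and importable, with these names re-exported through mlx.nn):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-1.0, 0.0, 1.0])
    print(nn.gelu(x))                 # exact: x * Phi(x)
    print(nn.gelu_approx(x))          # max abs error < 0.0003 on [-6, 6]
    print(nn.GELU(approx="fast")(x))  # module wrapper around gelu_fast_approx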
python/mlx/nn/losses.py (new file, 6 lines)
@@ -0,0 +1,6 @@
import mlx.core as mx


def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1):
    score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
    return mx.logsumexp(logits, axis=axis) - score
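
The loss relies on the identity -log softmax(logits)[target] = logsumexp(logits) - logits[target], which avoids computing an explicit softmax. A small worked example (a sketch, assuming the package is importable):

    import mlx.core as mx
    from mlx.nn.losses import cross_entropy

    logits = mx.array([[2.0, 0.5, -1.0]])
    targets = mx.array([0])
    # logsumexp(logits) ~= 2.2413, so the loss is ~= 2.2413 - 2.0 = 0.2413
    print(cross_entropy(logits, targets))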
python/mlx/utils.py (new file, 136 lines)
@@ -0,0 +1,136 @@
def tree_map(fn, tree, *rest):
    """Applies ``fn`` to the leaves of the python tree ``tree`` and
    returns a new collection with the results.

    If ``rest`` is provided, every item is assumed to be a superset of ``tree``
    and the corresponding leaves are provided as extra positional arguments to
    ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
    than to :func:`map`.

    .. code-block:: python

        import mlx.nn as nn
        from mlx.utils import tree_map

        model = nn.Linear(10, 10)
        print(model.parameters().keys())
        # dict_keys(['weight', 'bias'])

        # square the parameters
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
        fn (Callable): The function that processes the leaves of the tree
        tree (Any): The main python tree that will be iterated upon
        rest (Tuple[Any]): Extra trees to be iterated together with tree

    Returns:
        A python tree with the new values returned by ``fn``.
    """
    if isinstance(tree, list):
        return [
            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
        ]
    elif isinstance(tree, tuple):
        return tuple(
            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
        )
    elif isinstance(tree, dict):
        return {
            k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items()
        }
    else:
        return fn(tree, *rest)


def tree_flatten(tree, prefix="", is_leaf=None):
    """Flattens a python tree to a list of key, value tuples.

    The keys are using the dot notation to define trees of arbitrary depth and
    complexity.

    .. code-block:: python

        from mlx.utils import tree_flatten

        print(tree_flatten([[[0]]]))
        # [("0.0.0", 0)]

        print(tree_flatten([[[0]]], ".hello"))
        # [("hello.0.0.0", 0)]

    .. note::
        Dictionaries should have keys that are valid python identifiers.

    Args:
        tree (Any): The python tree to be flattened.
        prefix (str): A prefix to use for the keys. The first character is
            always discarded.
        is_leaf (Callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.

    Returns:
        List[Tuple[str, Any]]: The flat representation of the python tree.
    """
    flat_tree = []

    if is_leaf is None or not is_leaf(tree):
        if isinstance(tree, (list, tuple)):
            for i, t in enumerate(tree):
                flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
            return flat_tree
        if isinstance(tree, dict):
            for k, t in tree.items():
                flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
            return flat_tree

    return [(prefix[1:], tree)]


def tree_unflatten(tree):
    """Recreate a python tree from its flat representation.

    .. code-block:: python

        from mlx.utils import tree_unflatten

        d = tree_unflatten([("hello.world", 42)])
        print(d)
        # {"hello": {"world": 42}}

    Args:
        tree (List[Tuple[str, Any]]): The flat representation of a python tree.
            For instance as returned by :meth:`tree_flatten`.

    Returns:
        A python tree.
    """
    if len(tree) == 1 and tree[0][0] == "":
        return tree[0][1]

    try:
        int(tree[0][0].split(".", maxsplit=1)[0])
        is_list = True
    except ValueError:
        is_list = False

    # collect children
    children = {}
    for key, value in tree:
        current_idx, *next_idx = key.split(".", maxsplit=1)
        next_idx = "" if not next_idx else next_idx[0]
        if current_idx not in children:
            children[current_idx] = []
        children[current_idx].append((next_idx, value))

    # recursively map them to the original container
    if is_list:
        keys = sorted((int(idx), idx) for idx in children.keys())
        l = []
        for i, k in keys:
            while i > len(l):
                l.append({})
            l.append(tree_unflatten(children[k]))
        return l
    else:
        return {k: tree_unflatten(v) for k, v in children.items()}
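
A round trip through the two helpers above (an illustrative sketch): numeric path components mark list positions, identifier components mark dict keys, so flatten followed by unflatten reproduces the original structure.

    from mlx.utils import tree_flatten, tree_unflatten

    tree = {"layers": [{"w": 1}, {"w": 2}]}
    flat = tree_flatten(tree)
    # [('layers.0.w', 1), ('layers.1.w', 2)]
    assert tree_unflatten(flat) == tree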
python/src/array.cpp (new file, 1071 lines)
File diff suppressed because it is too large.
python/src/device.cpp (new file, 42 lines)
@@ -0,0 +1,42 @@
#include <sstream>

#include <pybind11/pybind11.h>

#include "mlx/device.h"
#include "mlx/utils.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

void init_device(py::module_& m) {
  py::enum_<Device::DeviceType>(m, "DeviceType")
      .value("cpu", Device::DeviceType::cpu)
      .value("gpu", Device::DeviceType::gpu)
      .export_values()
      .def(
          "__eq__",
          [](const Device::DeviceType& d1, const Device& d2) {
            return d1 == d2;
          },
          py::prepend());

  py::class_<Device>(m, "Device")
      .def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
      .def_readonly("type", &Device::type)
      .def(
          "__repr__",
          [](const Device& d) {
            std::ostringstream os;
            os << d;
            return os.str();
          })
      .def("__eq__", [](const Device& d1, const Device& d2) {
        return d1 == d2;
      });

  py::implicitly_convertible<Device::DeviceType, Device>();

  m.def("default_device", &default_device);
  m.def("set_default_device", &set_default_device, "device"_a);
}
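
These bindings give Python roughly the following surface (a sketch; mx.cpu and mx.gpu come from export_values() above, and the exact repr string is not shown in this diff):

    import mlx.core as mx

    print(mx.default_device())     # a Device, e.g. on the gpu
    mx.set_default_device(mx.cpu)  # DeviceType converts implicitly to Device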
python/src/metal.cpp (new file, 12 lines)
@@ -0,0 +1,12 @@
#include <pybind11/pybind11.h>

#include "mlx/backend/metal/metal.h"

namespace py = pybind11;

using namespace mlx::core;

void init_metal(py::module_& m) {
  py::module_ metal = m.def_submodule("metal", "mlx.metal");
  metal.def("is_available", &metal::is_available);
}
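
From Python this submodule is reachable as follows (a sketch based on the binding above):

    import mlx.core as mx

    if mx.metal.is_available():
        mx.set_default_device(mx.gpu)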
python/src/stream.cpp (new file, 32 lines)
@@ -0,0 +1,32 @@
#include <sstream>

#include <pybind11/pybind11.h>

#include "mlx/stream.h"
#include "mlx/utils.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

void init_stream(py::module_& m) {
  py::class_<Stream>(m, "Stream")
      .def(py::init<int, Device>(), "index"_a, "device"_a)
      .def_readonly("device", &Stream::device)
      .def(
          "__repr__",
          [](const Stream& s) {
            std::ostringstream os;
            os << s;
            return os.str();
          })
      .def("__eq__", [](const Stream& s1, const Stream& s2) {
        return s1 == s2;
      });

  py::implicitly_convertible<Device::DeviceType, Device>();

  m.def("default_stream", &default_stream, "device"_a);
  m.def("set_default_stream", &set_default_stream, "stream"_a);
  m.def("new_stream", &new_stream, "device"_a);
}
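
The corresponding Python usage (a sketch following the three functions bound above):

    import mlx.core as mx

    s = mx.new_stream(mx.default_device())
    print(mx.default_stream(mx.default_device()))
    mx.set_default_stream(s)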
python/tests/test_bf16.py (new file, 188 lines)
@@ -0,0 +1,188 @@
import unittest
from itertools import permutations

import math
import mlx.core as mx
import numpy as np

import mlx_tests

try:
    import torch

    has_torch = True
except ImportError as e:
    has_torch = False


class TestBF16(mlx_tests.MLXTestCase):
    def __test_ops(
        self,
        ref_op,  # Function that outputs array_like
        mlx_op,  # Function that outputs array_like
        np_args,  # Numpy arguments
        ref_transform=lambda x: x,
        mlx_transform=lambda x: mx.array(x),
        atol=1e-5,
    ):
        ref_args = map(ref_transform, np_args)
        mlx_args = map(mlx_transform, np_args)

        r_ref = ref_op(*ref_args)
        r_mlx = mlx_op(*mlx_args)

        self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol))

    def __default_test(
        self,
        op,
        np_args,
        simple_transform=lambda x: x,
        atol_np=1e-3,
        atol_torch=1e-5,
        np_kwargs=dict(),
        mlx_kwargs=dict(),
        torch_kwargs=dict(),
        torch_op=None,
    ):
        with self.subTest(reference="numpy"):

            def np_transform(x):
                x_mx_bf16 = mx.array(x).astype(mx.bfloat16)
                x_mx_fp32 = x_mx_bf16.astype(mx.float32)
                return np.asarray(x_mx_fp32)

            def mlx_fn(*args):
                out_bf16 = getattr(mx, op)(*args, **mlx_kwargs)
                return np.asarray(out_bf16.astype(mx.float32))

            def np_fn(*args):
                out_fp32 = getattr(np, op)(*args, **np_kwargs)
                return np_transform(out_fp32)

            ref_op = np_fn
            mlx_op = mlx_fn

            ref_transform = lambda x: simple_transform(np_transform(x))
            mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16))

            self.__test_ops(
                ref_op,
                mlx_op,
                np_args,
                ref_transform=ref_transform,
                mlx_transform=mlx_transform,
                atol=atol_np,
            )

        if has_torch:
            with self.subTest(reference="torch"):
                torch_op = op if torch_op is None else torch_op

                def torch_fn(*args):
                    out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs)
                    return out_bf16.to(torch.float32).numpy()

                ref_op = torch_fn
                ref_transform = lambda x: simple_transform(
                    torch.from_numpy(x).to(torch.bfloat16)
                )
                self.__test_ops(
                    ref_op,
                    mlx_op,
                    np_args,
                    ref_transform=ref_transform,
                    mlx_transform=mlx_transform,
                    atol=atol_torch,
                )

    def test_unary_ops(self):
        x = np.random.rand(18, 28, 38)
        for op in ["abs", "exp", "log", "square", "sqrt"]:
            with self.subTest(op=op):
                np_args = (x.astype(np.float32),)
                self.__default_test(op, np_args)

    def test_binary_ops(self):
        x = np.random.rand(18, 28, 38)
        y = np.random.rand(18, 28, 38)
        for op in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
            with self.subTest(op=op):
                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                self.__default_test(op, np_args, simple_transform=lambda x: x)
                self.__default_test(op, np_args, simple_transform=lambda x: x[:1])
                self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1])

    def test_reduction_ops(self):
        x = np.random.rand(18, 28, 38).astype(np.float32)

        for op in ("min", "max"):
            with self.subTest(op=op):

                for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
                    with self.subTest(axes=axes):
                        np_args = (x.astype(np.float32),)
                        self.__default_test(
                            op,
                            np_args,
                            np_kwargs={"axis": axes},
                            mlx_kwargs={"axis": axes},
                            torch_kwargs={"dim": axes},
                            torch_op="a" + op,
                        )

    def test_arg_reduction_ops(self):
        data = np.random.rand(10, 12, 13).astype(np.float32)
        x = mx.array(data).astype(mx.bfloat16)
        data = np.asarray(x.astype(mx.float32))

        for op in ["argmin", "argmax"]:
            for axis in range(3):
                for kd in [True, False]:
                    a = getattr(mx, op)(x, axis, kd)
                    b = getattr(np, op)(data, axis, keepdims=kd)
                    a = a.astype(mx.float32)
                    self.assertEqual(a.tolist(), b.tolist())

        for op in ["argmin", "argmax"]:
            a = getattr(mx, op)(x, keepdims=True)
            b = getattr(np, op)(data, keepdims=True)
            a = a.astype(mx.float32)
            self.assertEqual(a.tolist(), b.tolist())
            a = getattr(mx, op)(x)
            b = getattr(np, op)(data)
            a = a.astype(mx.float32)
            self.assertEqual(a.item(), b)

    def test_blas_ops(self):
        if mx.default_device() != mx.gpu:
            return

        def test_blas(shape_x, shape_y):
            np.random.seed(42)
            with self.subTest(shape_x=shape_x, shape_y=shape_y):
                x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x)
                y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y)

                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                op = "matmul"

                self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3)

        for shape_x, shape_y in [
            [(32, 32), (32, 32)],
            [(23, 57), (57, 1)],
            [(1, 3), (3, 128)],
            [(8, 128, 768), (768, 16)],
        ]:
            test_blas(shape_x, shape_y)


if __name__ == "__main__":
    unittest.main()
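
The tests compare against NumPy and Torch by casting through float32, since NumPy has no native bfloat16 type. The core pattern in isolation (a sketch, assuming mlx is built with bfloat16 support):

    import mlx.core as mx
    import numpy as np

    x32 = np.random.rand(4).astype(np.float32)
    x_bf16 = mx.array(x32).astype(mx.bfloat16)
    # compare in float32 space; bfloat16 keeps roughly 2-3 decimal digits
    print(np.allclose(np.asarray(x_bf16.astype(mx.float32)), x32, atol=1e-2))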