Module checks the weight on load_weights (#337)

* update Module to check weights on load; also fix docs and reorganize tests

* nits + rebase

* a few more docs updates for Module

* use manual module file

* comment
Awni Hannun, 2024-01-02 18:55:42 -08:00, committed by GitHub
parent 0782a4573a
commit dff4a3833f
6 changed files with 581 additions and 360 deletions
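
At a glance: ``Module.load_weights`` now accepts either a ``.npz`` path or a list of ``(name, array)`` pairs and, by default, validates names and shapes before updating. A minimal sketch of the new behavior, mirroring the tests added in this diff (the printed error text is illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(2, 2)

    # strict=True (the default) checks that names and shapes match the model.
    try:
        model.load_weights([("weight", mx.ones((2, 2)))])  # "bias" is missing
    except ValueError as e:
        print(e)  # e.g. "Missing parameters: bias."

    # strict=False updates only the matching weights and skips the checks.
    model.load_weights([("weight", mx.ones((2, 2)))], strict=False)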


@@ -0,0 +1,33 @@
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. add toctree option to make autodoc generate the pages

.. autoclass:: {{ objname }}

{% block attributes %}
{% if attributes %}
   .. rubric:: Attributes

   .. autosummary::
      :toctree: .
   {% for item in attributes %}
      ~{{ fullname }}.{{ item }}
   {%- endfor %}
{% endif %}
{% endblock %}

{% block methods %}
{% if methods %}
   .. rubric:: Methods

   .. autosummary::
      :toctree: .
   {% for item in methods %}
      {%- if item not in inherited_members and item != '__init__' %}
      ~{{ fullname }}.{{ item }}
      {%- endif -%}
   {%- endfor %}
{% endif %}
{% endblock %}
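
For orientation, a class template like the above only takes effect once Sphinx is told where to find it. A hypothetical ``conf.py`` wiring (standard Sphinx option names; the repo's actual configuration may differ):

    # conf.py (sketch; paths are assumptions, option names are standard Sphinx)
    extensions = [
        "sphinx.ext.autodoc",
        "sphinx.ext.autosummary",
    ]
    templates_path = ["_templates"]  # directory containing the template above
    autosummary_generate = True      # generate the per-object stub pages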


@@ -170,14 +170,13 @@ In detail:
 :meth:`mlx.core.value_and_grad`
 .. autosummary::
-   :recursive:
    :toctree: _autosummary
    value_and_grad
-   Module
 .. toctree::
+   nn/module
    nn/layers
    nn/functions
    nn/losses


@@ -0,0 +1,36 @@
Module
======

.. currentmodule:: mlx.nn

.. autoclass:: Module

.. rubric:: Attributes

.. autosummary::
   :toctree: _autosummary

   Module.training

.. rubric:: Methods

.. autosummary::
   :toctree: _autosummary

   Module.apply
   Module.apply_to_modules
   Module.children
   Module.eval
   Module.filter_and_map
   Module.freeze
   Module.leaf_modules
   Module.load_weights
   Module.modules
   Module.named_modules
   Module.parameters
   Module.save_weights
   Module.train
   Module.trainable_parameters
   Module.unfreeze
   Module.update
   Module.update_modules
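
To make the listed API concrete, a small sketch exercising a few of these attributes and methods (standard mlx.nn usage; the layer sizes are arbitrary):

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    print(model.training)        # modules start in training mode
    model.eval()                 # switch every submodule to evaluation mode

    model.freeze(keys="bias")    # stop computing gradients for all biases
    params = model.parameters()  # nested dict of every parameter array
    trainable = model.trainable_parameters()  # excludes the frozen biases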


@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 import textwrap
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 import mlx.core as mx
 from mlx.utils import tree_flatten, tree_unflatten
@@ -61,6 +61,7 @@ class Module(dict):
     @property
     def training(self):
+        """Boolean indicating if the model is in training mode."""
         return self._training

     def _extra_repr(self):
@@ -87,15 +88,83 @@ class Module(dict):
     def __setattr__(self, key: str, val: Any):
         self[key] = val

-    def load_weights(self, file: str):
+    def load_weights(
+        self,
+        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
+        strict: bool = True,
+    ):
         """
-        Load and update the model's weights from a `.npz` file.
+        Update the model's weights from a ``.npz`` or a list.
+
+        Args:
+            file_or_weights (str or list(tuple(str, mx.array))): The path to
+                the weights ``.npz`` file or a list of pairs of parameter names
+                and arrays.
+            strict (bool, optional): If ``True`` then checks that the provided
+                weights exactly match the parameters of the model. Otherwise,
+                only the weights actually contained in the model are loaded and
+                shapes are not checked. Default: ``True``.
+
+        Example:
+
+            .. code-block:: python
+
+                import mlx.core as mx
+                import mlx.nn as nn
+                model = nn.Linear(10, 10)
+
+                # Load from file
+                model.load_weights("weights.npz")
+
+                # Load from list
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                    ("bias", mx.zeros((10,))),
+                ]
+                model.load_weights(weights)
+
+                # Missing weight
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                ]
+
+                # Raises a ValueError exception
+                model.load_weights(weights)
+
+                # Ok, only updates the weight but not the bias
+                model.load_weights(weights, strict=False)
         """
-        self.update(tree_unflatten(list(mx.load(file).items())))
+        weights = file_or_weights
+        if isinstance(weights, str):
+            weights = list(mx.load(weights).items())
+
+        if strict:
+            new_weights = dict(weights)
+            curr_weights = dict(tree_flatten(self.parameters()))
+            if extras := (new_weights.keys() - curr_weights.keys()):
+                extras = " ".join(extras)
+                raise ValueError(f"Received parameters not in model: {extras}.")
+            if missing := (curr_weights.keys() - new_weights.keys()):
+                missing = " ".join(missing)
+                raise ValueError(f"Missing parameters: {missing}.")
+            for k, v in curr_weights.items():
+                v_new = new_weights[k]
+                if not isinstance(v_new, mx.array):
+                    raise ValueError(
+                        "Expected mx.array but received "
+                        f"{type(v_new)} for parameter {k}"
+                    )
+                if v_new.shape != v.shape:
+                    raise ValueError(
+                        f"Expected shape {v.shape} but received "
+                        f"shape {v_new.shape} for parameter {k}"
+                    )
+
+        self.update(tree_unflatten(weights))

     def save_weights(self, file: str):
         """
-        Save the model's weights to a `.npz` file.
+        Save the model's weights to a ``.npz`` file.
         """
         mx.savez(file, **dict(tree_flatten(self.parameters())))
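
The strict checks above hinge on ``tree_flatten`` turning the nested parameter dict into dotted-name/array pairs. A quick sketch of that round trip, using the same ``mlx.utils`` helpers this file already imports:

    import mlx.nn as nn
    from mlx.utils import tree_flatten, tree_unflatten

    model = nn.Sequential(nn.Linear(2, 2))
    flat = tree_flatten(model.parameters())
    print([name for name, _ in flat])  # e.g. ['layers.0.weight', 'layers.0.bias']

    # tree_unflatten inverts the naming, which is why load_weights can feed
    # a (name, array) list straight into self.update(...)
    nested = tree_unflatten(flat)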
@@ -351,23 +420,26 @@ class Module(dict):
         """Freeze the Module's parameters or some of them. Freezing a parameter means not
         computing gradients for it.

-        This function is idempotent ie freezing a frozen model is a noop.
+        This function is idempotent i.e. freezing a frozen model is a no-op.

-        For instance to only train the attention parameters from a transformer:
-
-            model = ...
+        Example:
+            For instance to only train the attention parameters from a Transformer:
+
+            .. code-block:: python
+
+                model = nn.Transformer()
                 model.freeze()
                 model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

         Args:
             recurse (bool, optional): If True then freeze the parameters of the
-                submodules as well (default: True).
+                submodules as well. Default: ``True``.
             keys (str or list[str], optional): If provided then only these
                 parameters will be frozen otherwise all the parameters of a
                 module. For instance freeze all biases by calling
                 ``module.freeze(keys="bias")``.
-            strict (bool, optional): If set to True validate that the passed keys exist
-                (default: False).
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
         """

         def _freeze_impl(_, m):
@@ -401,21 +473,25 @@ class Module(dict):
         This function is idempotent ie unfreezing a model that is not frozen is
         a noop.

-        For instance to only train the biases one can do:
-
-            model = ...
+        Example:
+            For instance to only train the biases of a Transformer one can do:
+
+            .. code-block:: python
+
+                model = nn.Transformer()
                 model.freeze()
                 model.unfreeze(keys="bias")

         Args:
             recurse (bool, optional): If True then unfreeze the parameters of the
-                submodules as well (default: True).
+                submodules as well. Default: ``True``.
             keys (str or list[str], optional): If provided then only these
                 parameters will be unfrozen otherwise all the parameters of a
                 module. For instance unfreeze all biases by calling
                 ``module.unfreeze(keys="bias")``.
-            strict (bool, optional): If set to True validate that the passed keys exist
-                (default: False).
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
         """

         def _unfreeze_impl(_, m):
@@ -432,10 +508,25 @@ class Module(dict):
         _unfreeze_impl("", self)

     def train(self, mode: bool = True):
+        """Set the model in or out of training mode.
+
+        Training mode only applies to certain layers. For example
+        :obj:`Dropout` applies a random mask in training mode, but is the
+        identity in evaluation mode.
+
+        Args:
+            mode (bool): Indicate if the model should be in training or
+                evaluation mode. Default: ``True``.
+        """
         def _set_train(_, m):
             m._training = mode

         self.apply_to_modules(_set_train)

     def eval(self):
+        """Set the model to evaluation mode.
+
+        See :func:`train`.
+        """
         self.train(False)
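
A short sketch of the newly documented train/eval switch, with Dropout as the canonical mode-sensitive layer (this mirrors the docstrings above):

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Dropout(0.5)
    x = mx.ones((2, 4))

    y_train = layer(x)  # training mode (the default): random masking + rescaling
    layer.eval()
    y_eval = layer(x)   # evaluation mode: Dropout is the identity, y_eval == x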

python/tests/test_losses.py (new file)

@@ -0,0 +1,279 @@
# Copyright © 2023 Apple Inc.

import unittest

import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np


class TestLosses(mlx_tests.MLXTestCase):
    def test_cross_entropy(self):
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])

        # Test with reduction 'none'
        losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
        expected_none = mx.array([0.0, 0.0])
        self.assertTrue(mx.array_equal(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertEqual(losses_mean, expected_mean)

        # Test with reduction 'sum'
        losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertEqual(losses_sum, expected_sum)

        # Test cases with weights and no label smoothing
        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
        targets = mx.array([0, 1])
        weights = mx.array([1.0, 2.0])

        # Reduction 'none'
        losses_none = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="none",
        )
        expected_none = mx.array([0.04858735, 0.0971747])  # Calculated losses
        self.assertTrue(
            np.allclose(losses_none, expected_none, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
        )

        # Reduction 'mean'
        losses_mean = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="mean",
        )
        expected_mean = mx.mean(expected_none)
        self.assertTrue(
            np.allclose(losses_mean, expected_mean, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
        )

        # Reduction 'sum'
        losses_sum = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="sum",
        )
        expected_sum = mx.sum(expected_none)
        self.assertTrue(
            np.allclose(losses_sum, expected_sum, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
        )

        # Test case with equal weights and label smoothing > 0
        logits = mx.array(
            [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
        )
        target = mx.array([2, 1, 0])

        losses_none = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="none"
        )
        expected_none = mx.array([1.29693, 1.38617, 1.48176])
        self.assertTrue(
            mx.allclose(expected_none, losses_none),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
        )

        expected_mean = mx.mean(expected_none)
        losses_mean = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="mean"
        )
        self.assertTrue(
            mx.allclose(losses_mean, expected_mean),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
        )

        expected_sum = mx.sum(expected_none)
        losses_sum = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="sum"
        )
        self.assertTrue(
            mx.allclose(losses_sum, expected_sum),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
        )

    def test_l1_loss(self):
        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
        targets = mx.array([0.5, 0.2, 0.9, 0.0])

        # Expected result
        expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
        expected_sum = mx.sum(expected_none)
        expected_mean = mx.mean(expected_none)

        losses = nn.losses.l1_loss(predictions, targets, reduction="none")
        self.assertTrue(
            mx.array_equal(losses, expected_none),
            "Test failed for l1_loss --reduction='none'",
        )

        losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
        self.assertTrue(mx.array_equal(losses, expected_sum))

        losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
        self.assertTrue(mx.array_equal(losses, expected_mean))

    def test_mse_loss(self):
        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
        targets = mx.array([0.7, 0.1, 0.8, 0.2])

        expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
        expected_mean = mx.mean(expected_none)
        expected_sum = mx.sum(expected_none)

        # Test with reduction 'none'
        losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
        self.assertTrue(
            np.allclose(losses_none, expected_none, 1e-5),
            "Test case failed for mse_loss --reduction='none'",
        )

        # Test with reduction 'mean'
        losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
        self.assertEqual(
            losses_mean,
            expected_mean,
            "Test case failed for mse_loss --reduction='mean'",
        )

        # Test with reduction 'sum'
        losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
        self.assertEqual(
            losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'"
        )

    def test_smooth_l1_loss(self):
        predictions = mx.array([1.5, 2.5, 0.5, 3.5])
        targets = mx.array([1.0, 2.0, 0.5, 2.5])
        beta = 1.0

        # Expected results
        expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
        expected_sum = mx.sum(expected_none)
        expected_mean = mx.mean(expected_none)

        # Test with reduction 'none'
        loss_none = nn.losses.smooth_l1_loss(
            predictions, targets, beta, reduction="none"
        )
        self.assertTrue(
            mx.array_equal(loss_none, expected_none),
            "Test case failed for smooth_l1_loss --reduction='none'",
        )

        # Test with reduction 'sum'
        loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
        self.assertEqual(
            loss_sum,
            expected_sum,
            "Test case failed for smooth_l1_loss --reduction='sum'",
        )

        # Test with reduction 'mean'
        loss_mean = nn.losses.smooth_l1_loss(
            predictions, targets, beta, reduction="mean"
        )
        self.assertEqual(
            loss_mean,
            expected_mean,
            "Test case failed for smooth_l1_loss --reduction='mean'",
        )

    def test_nll_loss(self):
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])

        # Test with reduction 'none'
        losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
        expected_none = mx.array([0.0, 0.0])
        self.assertTrue(mx.array_equal(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertEqual(losses_mean, expected_mean)

        # Test with reduction 'sum'
        losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertEqual(losses_sum, expected_sum)

    def test_kl_div_loss(self):
        p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
        q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))

        # Test with reduction 'none'
        losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
        expected_none = mx.array([0.0, 0.831777])
        self.assertTrue(mx.allclose(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertTrue(mx.allclose(losses_mean, expected_mean))

        # Test with reduction 'sum'
        losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertTrue(mx.allclose(losses_sum, expected_sum))

    def test_triplet_loss(self):
        anchors = mx.array([[1, 2, 3], [1, 2, 3]])
        positives = mx.array([[4, 5, 6], [0, -1, 2]])
        negatives = mx.array([[7, 8, 9], [3, 2, 3]])

        # Test with reduction 'none'
        losses_none = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="none"
        )
        expected_none = mx.array([0, 2.31662])
        self.assertTrue(mx.allclose(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="mean"
        )
        expected_mean = mx.mean(expected_none)
        self.assertTrue(mx.allclose(losses_mean, expected_mean))

        # Test with reduction 'sum'
        losses_sum = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="sum"
        )
        expected_sum = mx.sum(expected_none)
        self.assertTrue(mx.allclose(losses_sum, expected_sum))

    def test_hinge_loss(self):
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
        self.assertEqual(loss, 1.0)

    def test_huber_loss(self):
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
        self.assertEqual(loss, 0.5)

    def test_log_cosh_loss(self):
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
        self.assertAlmostEqual(loss.item(), 0.433781, places=6)


if __name__ == "__main__":
    unittest.main()


@@ -11,7 +11,112 @@ import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten


-class TestNN(mlx_tests.MLXTestCase):
+class TestBase(mlx_tests.MLXTestCase):
+    def test_module_utilities(self):
+        m = nn.Sequential(
+            nn.Sequential(nn.Linear(2, 10), nn.relu),
+            nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
+            nn.Linear(10, 1),
+            mx.sigmoid,
+        )
+
+        children = m.children()
+        self.assertTrue(isinstance(children, dict))
+        self.assertEqual(len(children), 1)
+        self.assertTrue(isinstance(children["layers"], list))
+        self.assertEqual(len(children["layers"]), 4)
+        self.assertEqual(children["layers"][3], {})
+        flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
+        self.assertEqual(len(flat_children), 3)
+
+        leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
+        self.assertEqual(len(leaves), 4)
+        self.assertEqual(leaves[0][0], "layers.0.layers.0")
+        self.assertEqual(leaves[1][0], "layers.1.layers.0")
+        self.assertEqual(leaves[2][0], "layers.1.layers.1")
+        self.assertEqual(leaves[3][0], "layers.2")
+        self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
+        self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
+        self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
+        self.assertTrue(leaves[3][1] is m.layers[2])
+
+        m.eval()
+
+        def assert_not_training(k, m):
+            self.assertFalse(m.training)
+
+        m.apply_to_modules(assert_not_training)
+
+        m.train()
+
+        def assert_training(k, m):
+            self.assertTrue(m.training)
+
+        m.apply_to_modules(assert_training)
+
+    def test_io(self):
+        def make_model():
+            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
+
+        m = make_model()
+        tdir = tempfile.TemporaryDirectory()
+        file = os.path.join(tdir.name, "model.npz")
+        m.save_weights(file)
+        m_load = make_model()
+        m_load.load_weights(file)
+        tdir.cleanup()
+
+        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
+        self.assertTrue(all(tree_flatten(eq_tree)))
+
+    def test_load_from_weights(self):
+        m = nn.Linear(2, 2)
+
+        # Too few weights
+        weights = [("weight", mx.ones((2, 2)))]
+        with self.assertRaises(ValueError):
+            m.load_weights(weights)
+
+        m.load_weights(weights, strict=False)
+        self.assertTrue(mx.array_equal(m.weight, weights[0][1]))
+
+        # Wrong name
+        with self.assertRaises(ValueError):
+            m.load_weights([("weihgt", mx.ones((2, 2)))])
+
+        # Ok
+        m.load_weights([("weihgt", mx.ones((2, 2)))], strict=False)
+
+        # Too many weights
+        with self.assertRaises(ValueError):
+            m.load_weights(
+                [
+                    ("weight", mx.ones((2, 2))),
+                    ("bias", mx.ones((2,))),
+                    ("bias2", mx.ones((2,))),
+                ]
+            )
+
+        # Wrong shape
+        with self.assertRaises(ValueError):
+            m.load_weights(
+                [
+                    ("weight", mx.ones((2, 2))),
+                    ("bias", mx.ones((2, 1))),
+                ]
+            )
+
+        # Wrong type
+        with self.assertRaises(ValueError):
+            m.load_weights(
+                [
+                    ("weight", mx.ones((2, 2))),
+                    ("bias", 3),
+                ]
+            )
+
+
+class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
         inputs = mx.zeros((10, 4))
         layer = nn.Identity()
@@ -31,272 +136,6 @@ class TestNN(mlx_tests.MLXTestCase):
         outputs = layer(inputs1, inputs2)
         self.assertEqual(tuple(outputs.shape), (10, 6))

-    [deleted: test_cross_entropy, test_l1_loss, test_mse_loss,
-     test_smooth_l1_loss, test_nll_loss, test_kl_div_loss, and
-     test_triplet_loss, which move verbatim to the new
-     python/tests/test_losses.py above; test_gelu moves further down
-     in this file]
     def test_group_norm(self):
         x = mx.arange(100, dtype=mx.float32)
         x = x.reshape(1, 10, 10, 1)
@@ -570,47 +409,24 @@
         y2 = m(x)
         self.assertTrue(mx.array_equal(y, y2))
-    [deleted: test_module_utilities, moved verbatim to TestBase above]
+    def test_gelu(self):
+        inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
+
+        # From: jax.nn.gelu(np.array(inputs), approximate=False)
+        expected = np.array(
+            [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
+        )
+
+        out = nn.GELU()(mx.array(inputs))
+        self.assertTrue(np.allclose(out, expected))
+
+        # Crudely check the approximations
+        x = mx.arange(-6.0, 6.0, 12 / 100)
+        y = nn.gelu(x)
+        y_hat1 = nn.gelu_approx(x)
+        y_hat2 = nn.gelu_fast_approx(x)
+        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
+        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
     def test_sin_pe(self):
         m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
@@ -623,21 +439,6 @@
             mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
         )

-    [deleted: test_io, moved verbatim to TestBase above]
     def test_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.relu(x)
@@ -787,24 +588,6 @@
         y = alibi(x.astype(mx.float16))
         self.assertTrue(y.dtype, mx.float16)

-    [deleted: test_hinge_loss, test_huber_loss, and test_log_cosh_loss,
-     which move verbatim to the new python/tests/test_losses.py above]
     def test_dropout(self):
         x = mx.ones((2, 4))
         y = nn.Dropout(0.5)(x)