Module checks the weight on load_weights (#337)

* update module to check weights on load, also fix docs and reorganize tests

* nits + rebase

* a few more docs updates for Module

* use manual module file

* comment
This commit is contained in:
Awni Hannun 2024-01-02 18:55:42 -08:00 committed by GitHub
parent 0782a4573a
commit dff4a3833f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 581 additions and 360 deletions

View File

@ -0,0 +1,33 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. add toctree option to make autodoc generate the pages
.. autoclass:: {{ objname }}
{% block attributes %}
{% if attributes %}
.. rubric:: Attributes
.. autosummary::
:toctree: .
{% for item in attributes %}
~{{ fullname }}.{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}
{% block methods %}
{% if methods %}
.. rubric:: Methods
.. autosummary::
:toctree: .
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ fullname }}.{{ item }}
{%- endif -%}
{%- endfor %}
{% endif %}
{% endblock %}

View File

@ -170,14 +170,13 @@ In detail:
:meth:`mlx.core.value_and_grad`
.. autosummary::
:recursive:
:toctree: _autosummary
value_and_grad
Module
.. toctree::
nn/module
nn/layers
nn/functions
nn/losses

View File

@ -0,0 +1,36 @@
Module
======
.. currentmodule:: mlx.nn
.. autoclass:: Module
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Module.training
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Module.apply
Module.apply_to_modules
Module.children
Module.eval
Module.filter_and_map
Module.freeze
Module.leaf_modules
Module.load_weights
Module.modules
Module.named_modules
Module.parameters
Module.save_weights
Module.train
Module.trainable_parameters
Module.unfreeze
Module.update
Module.update_modules

View File

@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
import textwrap
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
@ -61,6 +61,7 @@ class Module(dict):
@property
def training(self):
"""Boolean indicating if the model is in training mode."""
return self._training
def _extra_repr(self):
@ -87,15 +88,83 @@ class Module(dict):
def __setattr__(self, key: str, val: Any):
self[key] = val
def load_weights(self, file: str):
def load_weights(
self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True,
):
"""
Load and update the model's weights from a `.npz` file.
Update the model's weights from a ``.npz`` or a list.
Args:
file_or_weights (str or list(tuple(str, mx.array))): The path to
the weights ``.npz`` file or a list of pairs of parameter names
and arrays.
strict (bool, optional): If ``True`` then checks that the provided
weights exactly match the parameters of the model. Otherwise,
only the weights actually contained in the model are loaded and
shapes are not checked. Default: ``True``.
Example:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
model = nn.Linear(10, 10)
# Load from file
model.load_weights("weights.npz")
# Load from list
weights = [
("weight", mx.random.uniform(shape=(10, 10))),
("bias", mx.zeros((10,))),
]
model.load_weights(weights)
# Missing weight
weights = [
("weight", mx.random.uniform(shape=(10, 10))),
]
# Raises a ValueError exception
model.load_weights(weights)
# Ok, only updates the weight but not the bias
model.load_weights(weights, strict=False)
"""
self.update(tree_unflatten(list(mx.load(file).items())))
weights = file_or_weights
if isinstance(weights, str):
weights = list(mx.load(weights).items())
if strict:
new_weights = dict(weights)
curr_weights = dict(tree_flatten(self.parameters()))
if extras := (new_weights.keys() - curr_weights.keys()):
extras = " ".join(extras)
raise ValueError(f"Received parameters not in model: {extras}.")
if missing := (curr_weights.keys() - new_weights.keys()):
missing = " ".join(missing)
raise ValueError(f"Missing parameters: {missing}.")
for k, v in curr_weights.items():
v_new = new_weights[k]
if not isinstance(v_new, mx.array):
raise ValueError(
"Expected mx.array but received "
f"{type(v_new)} for parameter {k}"
)
if v_new.shape != v.shape:
raise ValueError(
f"Expected shape {v.shape} but received "
f" shape {v_new.shape} for parameter {k}"
)
self.update(tree_unflatten(weights))
def save_weights(self, file: str):
"""
Save the model's weights to a `.npz` file.
Save the model's weights to a ``.npz`` file.
"""
mx.savez(file, **dict(tree_flatten(self.parameters())))
@ -351,23 +420,26 @@ class Module(dict):
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.
This function is idempotent ie freezing a frozen model is a noop.
This function is idempotent i.e. freezing a frozen model is a no-op.
For instance to only train the attention parameters from a transformer:
Example:
For instance to only train the attention parameters from a Transformer:
model = ...
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
.. code-block:: python
model = nn.Transformer()
model.freeze()
model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
Args:
recurse (bool, optional): If True then freeze the parameters of the
submodules as well (default: True).
submodules as well. Default: ``True``.
keys (str or list[str], optional): If provided then only these
parameters will be frozen otherwise all the parameters of a
module. For instance freeze all biases by calling
``module.freeze(keys="bias")``.
strict (bool, optional): If set to True validate that the passed keys exist
(default: False).
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``.
"""
def _freeze_impl(_, m):
@ -401,21 +473,25 @@ class Module(dict):
This function is idempotent ie unfreezing a model that is not frozen is
a noop.
For instance to only train the biases one can do:
Example:
model = ...
model.freeze()
model.unfreeze(keys="bias")
For instance to only train the biases of a Transformer one can do:
.. code-block:: python
model = nn.Transformer()
model.freeze()
model.unfreeze(keys="bias")
Args:
recurse (bool, optional): If True then unfreeze the parameters of the
submodules as well (default: True).
submodules as well. Default: ``True``.
keys (str or list[str], optional): If provided then only these
parameters will be unfrozen otherwise all the parameters of a
module. For instance unfreeze all biases by calling
``module.unfreeze(keys="bias")``.
strict (bool, optional): If set to True validate that the passed keys exist
(default: False).
strict (bool, optional): If set to ``True`` validate that the passed keys exist.
Default: ``False``.
"""
def _unfreeze_impl(_, m):
@ -432,10 +508,25 @@ class Module(dict):
_unfreeze_impl("", self)
def train(self, mode: bool = True):
"""Set the model in or out of training mode.
Training mode only applies to certain layers. For example
:obj:`Dropout` applies a random mask in training mode, but is the
identity in evaluation mode.
Args:
mode (bool): Indicate if the model should be in training or
evaluation mode. Default: ``True``.
"""
def _set_train(_, m):
m._training = mode
self.apply_to_modules(_set_train)
def eval(self):
"""Set the model to evaluation mode.
See :func:`train`.
"""
self.train(False)

279
python/tests/test_losses.py Normal file
View File

@ -0,0 +1,279 @@
# Copyright © 2023 Apple Inc.
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
class TestLosses(mlx_tests.MLXTestCase):
def test_cross_entropy(self):
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])
# Test with reduction 'none'
losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
expected_none = mx.array([0.0, 0.0])
self.assertTrue(mx.array_equal(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
# Test with reduction 'sum'
losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
# Test cases with weights and no label smoothing
logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
targets = mx.array([0, 1])
weights = mx.array([1.0, 2.0])
# Reduction 'none'
losses_none = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="none",
)
expected_none = mx.array([0.04858735, 0.0971747]) # Calculated losses
self.assertTrue(
np.allclose(losses_none, expected_none, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
)
# Reduction 'mean'
losses_mean = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="mean",
)
expected_mean = mx.mean(expected_none)
self.assertTrue(
np.allclose(losses_mean, expected_mean, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
)
# Reduction 'sum'
losses_sum = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="sum",
)
expected_sum = mx.sum(expected_none)
self.assertTrue(
np.allclose(losses_sum, expected_sum, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
)
# Test case with equal weights and label smoothing > 0
logits = mx.array(
[[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
)
target = mx.array([2, 1, 0])
losses_none = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="none"
)
expected_none = mx.array([1.29693, 1.38617, 1.48176])
self.assertTrue(
mx.allclose(expected_none, losses_none),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
)
expected_mean = mx.mean(expected_none)
losses_mean = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="mean"
)
self.assertTrue(
mx.allclose(losses_mean, expected_mean),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
)
expected_sum = mx.sum(expected_none)
losses_sum = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="sum"
)
self.assertTrue(
mx.allclose(losses_sum, expected_sum),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
)
def test_l1_loss(self):
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
targets = mx.array([0.5, 0.2, 0.9, 0.0])
# Expected result
expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
expected_sum = mx.sum(expected_none)
expected_mean = mx.mean(expected_none)
losses = nn.losses.l1_loss(predictions, targets, reduction="none")
self.assertTrue(
mx.array_equal(losses, expected_none),
"Test failed for l1_loss --reduction='none'",
)
losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
self.assertTrue(mx.array_equal(losses, expected_sum))
losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
self.assertTrue(mx.array_equal(losses, expected_mean))
def test_mse_loss(self):
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
targets = mx.array([0.7, 0.1, 0.8, 0.2])
expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
expected_mean = mx.mean(expected_none)
expected_sum = mx.sum(expected_none)
# Test with reduction 'none'
losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
self.assertTrue(
np.allclose(losses_none, expected_none, 1e-5),
"Test case failed for mse_loss --reduction='none'",
)
# Test with reduction 'mean'
losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
self.assertEqual(
losses_mean,
expected_mean,
"Test case failed for mse_loss --reduction='mean'",
)
# Test with reduction 'sum'
losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
self.assertEqual(
losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'"
)
def test_smooth_l1_loss(self):
predictions = mx.array([1.5, 2.5, 0.5, 3.5])
targets = mx.array([1.0, 2.0, 0.5, 2.5])
beta = 1.0
# Expected results
expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
expected_sum = mx.sum(expected_none)
expected_mean = mx.mean(expected_none)
# Test with reduction 'none'
loss_none = nn.losses.smooth_l1_loss(
predictions, targets, beta, reduction="none"
)
self.assertTrue(
mx.array_equal(loss_none, expected_none),
"Test case failed for smooth_l1_loss --reduction='none'",
)
# Test with reduction 'sum'
loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
self.assertEqual(
loss_sum,
expected_sum,
"Test case failed for smooth_l1_loss --reduction='sum'",
)
# Test with reduction 'mean'
loss_mean = nn.losses.smooth_l1_loss(
predictions, targets, beta, reduction="mean"
)
self.assertEqual(
loss_mean,
expected_mean,
"Test case failed for smooth_l1_loss --reduction='mean'",
)
def test_nll_loss(self):
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])
# Test with reduction 'none'
losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
expected_none = mx.array([0.0, 0.0])
self.assertTrue(mx.array_equal(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
# Test with reduction 'sum'
losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
def test_kl_div_loss(self):
p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))
# Test with reduction 'none'
losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
expected_none = mx.array([0.0, 0.831777])
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_triplet_loss(self):
anchors = mx.array([[1, 2, 3], [1, 2, 3]])
positives = mx.array([[4, 5, 6], [0, -1, 2]])
negatives = mx.array([[7, 8, 9], [3, 2, 3]])
# Test with reduction 'none'
losses_none = nn.losses.triplet_loss(
anchors, positives, negatives, reduction="none"
)
expected_none = mx.array([0, 2.31662])
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.triplet_loss(
anchors, positives, negatives, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.triplet_loss(
anchors, positives, negatives, reduction="sum"
)
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_hinge_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 1.0)
def test_huber_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 0.5)
def test_log_cosh_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
if __name__ == "__main__":
unittest.main()

View File

@ -11,7 +11,112 @@ import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
class TestNN(mlx_tests.MLXTestCase):
class TestBase(mlx_tests.MLXTestCase):
def test_module_utilities(self):
m = nn.Sequential(
nn.Sequential(nn.Linear(2, 10), nn.relu),
nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
nn.Linear(10, 1),
mx.sigmoid,
)
children = m.children()
self.assertTrue(isinstance(children, dict))
self.assertEqual(len(children), 1)
self.assertTrue(isinstance(children["layers"], list))
self.assertEqual(len(children["layers"]), 4)
self.assertEqual(children["layers"][3], {})
flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
self.assertEqual(len(flat_children), 3)
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
self.assertEqual(len(leaves), 4)
self.assertEqual(leaves[0][0], "layers.0.layers.0")
self.assertEqual(leaves[1][0], "layers.1.layers.0")
self.assertEqual(leaves[2][0], "layers.1.layers.1")
self.assertEqual(leaves[3][0], "layers.2")
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
self.assertTrue(leaves[3][1] is m.layers[2])
m.eval()
def assert_not_training(k, m):
self.assertFalse(m.training)
m.apply_to_modules(assert_not_training)
m.train()
def assert_training(k, m):
self.assertTrue(m.training)
m.apply_to_modules(assert_training)
def test_io(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
m = make_model()
tdir = tempfile.TemporaryDirectory()
file = os.path.join(tdir.name, "model.npz")
m.save_weights(file)
m_load = make_model()
m_load.load_weights(file)
tdir.cleanup()
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
self.assertTrue(all(tree_flatten(eq_tree)))
def test_load_from_weights(self):
m = nn.Linear(2, 2)
# Too few weights
weights = [("weight", mx.ones((2, 2)))]
with self.assertRaises(ValueError):
m.load_weights(weights)
m.load_weights(weights, strict=False)
self.assertTrue(mx.array_equal(m.weight, weights[0][1]))
# Wrong name
with self.assertRaises(ValueError):
m.load_weights([("weihgt", mx.ones((2, 2)))])
# Ok
m.load_weights([("weihgt", mx.ones((2, 2)))], strict=False)
# Too many weights
with self.assertRaises(ValueError):
m.load_weights(
[
("weight", mx.ones((2, 2))),
("bias", mx.ones((2,))),
("bias2", mx.ones((2,))),
]
)
# Wrong shape
with self.assertRaises(ValueError):
m.load_weights(
[
("weight", mx.ones((2, 2))),
("bias", mx.ones((2, 1))),
]
)
# Wrong type
with self.assertRaises(ValueError):
m.load_weights(
[
("weight", mx.ones((2, 2))),
("bias", 3),
]
)
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):
inputs = mx.zeros((10, 4))
layer = nn.Identity()
@ -31,272 +136,6 @@ class TestNN(mlx_tests.MLXTestCase):
outputs = layer(inputs1, inputs2)
self.assertEqual(tuple(outputs.shape), (10, 6))
def test_cross_entropy(self):
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])
# Test with reduction 'none'
losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
expected_none = mx.array([0.0, 0.0])
self.assertTrue(mx.array_equal(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
# Test with reduction 'sum'
losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
# Test cases with weights and no label smoothing
logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
targets = mx.array([0, 1])
weights = mx.array([1.0, 2.0])
# Reduction 'none'
losses_none = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="none",
)
expected_none = mx.array([0.04858735, 0.0971747]) # Calculated losses
self.assertTrue(
np.allclose(losses_none, expected_none, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
)
# Reduction 'mean'
losses_mean = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="mean",
)
expected_mean = mx.mean(expected_none)
self.assertTrue(
np.allclose(losses_mean, expected_mean, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
)
# Reduction 'sum'
losses_sum = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="sum",
)
expected_sum = mx.sum(expected_none)
self.assertTrue(
np.allclose(losses_sum, expected_sum, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
)
# Test case with equal weights and label smoothing > 0
logits = mx.array(
[[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
)
target = mx.array([2, 1, 0])
losses_none = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="none"
)
expected_none = mx.array([1.29693, 1.38617, 1.48176])
self.assertTrue(
mx.allclose(expected_none, losses_none),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
)
expected_mean = mx.mean(expected_none)
losses_mean = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="mean"
)
self.assertTrue(
mx.allclose(losses_mean, expected_mean),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
)
expected_sum = mx.sum(expected_none)
losses_sum = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="sum"
)
self.assertTrue(
mx.allclose(losses_sum, expected_sum),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
)
def test_l1_loss(self):
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
targets = mx.array([0.5, 0.2, 0.9, 0.0])
# Expected result
expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
expected_sum = mx.sum(expected_none)
expected_mean = mx.mean(expected_none)
losses = nn.losses.l1_loss(predictions, targets, reduction="none")
self.assertTrue(
mx.array_equal(losses, expected_none),
"Test failed for l1_loss --reduction='none'",
)
losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
self.assertTrue(mx.array_equal(losses, expected_sum))
losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
self.assertTrue(mx.array_equal(losses, expected_mean))
def test_mse_loss(self):
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
targets = mx.array([0.7, 0.1, 0.8, 0.2])
expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
expected_mean = mx.mean(expected_none)
expected_sum = mx.sum(expected_none)
# Test with reduction 'none'
losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
self.assertTrue(
np.allclose(losses_none, expected_none, 1e-5),
"Test case failed for mse_loss --reduction='none'",
)
# Test with reduction 'mean'
losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
self.assertEqual(
losses_mean,
expected_mean,
"Test case failed for mse_loss --reduction='mean'",
)
# Test with reduction 'sum'
losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
self.assertEqual(
losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'"
)
def test_smooth_l1_loss(self):
predictions = mx.array([1.5, 2.5, 0.5, 3.5])
targets = mx.array([1.0, 2.0, 0.5, 2.5])
beta = 1.0
# Expected results
expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
expected_sum = mx.sum(expected_none)
expected_mean = mx.mean(expected_none)
# Test with reduction 'none'
loss_none = nn.losses.smooth_l1_loss(
predictions, targets, beta, reduction="none"
)
self.assertTrue(
mx.array_equal(loss_none, expected_none),
"Test case failed for smooth_l1_loss --reduction='none'",
)
# Test with reduction 'sum'
loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
self.assertEqual(
loss_sum,
expected_sum,
"Test case failed for smooth_l1_loss --reduction='sum'",
)
# Test with reduction 'mean'
loss_mean = nn.losses.smooth_l1_loss(
predictions, targets, beta, reduction="mean"
)
self.assertEqual(
loss_mean,
expected_mean,
"Test case failed for smooth_l1_loss --reduction='mean'",
)
def test_nll_loss(self):
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])
# Test with reduction 'none'
losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
expected_none = mx.array([0.0, 0.0])
self.assertTrue(mx.array_equal(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
# Test with reduction 'sum'
losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
def test_kl_div_loss(self):
p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))
# Test with reduction 'none'
losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
expected_none = mx.array([0.0, 0.831777])
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_triplet_loss(self):
anchors = mx.array([[1, 2, 3], [1, 2, 3]])
positives = mx.array([[4, 5, 6], [0, -1, 2]])
negatives = mx.array([[7, 8, 9], [3, 2, 3]])
# Test with reduction 'none'
losses_none = nn.losses.triplet_loss(
anchors, positives, negatives, reduction="none"
)
expected_none = mx.array([0, 2.31662])
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.triplet_loss(
anchors, positives, negatives, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.triplet_loss(
anchors, positives, negatives, reduction="sum"
)
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_gelu(self):
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
# From: jax.nn.gelu(np.array(inputs), approximate=False)
expected = np.array(
[1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
)
out = nn.GELU()(mx.array(inputs))
self.assertTrue(np.allclose(out, expected))
# Crudely check the approximations
x = mx.arange(-6.0, 6.0, 12 / 100)
y = nn.gelu(x)
y_hat1 = nn.gelu_approx(x)
y_hat2 = nn.gelu_fast_approx(x)
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
def test_group_norm(self):
x = mx.arange(100, dtype=mx.float32)
x = x.reshape(1, 10, 10, 1)
@ -570,47 +409,24 @@ class TestNN(mlx_tests.MLXTestCase):
y2 = m(x)
self.assertTrue(mx.array_equal(y, y2))
def test_module_utilities(self):
m = nn.Sequential(
nn.Sequential(nn.Linear(2, 10), nn.relu),
nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
nn.Linear(10, 1),
mx.sigmoid,
def test_gelu(self):
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
# From: jax.nn.gelu(np.array(inputs), approximate=False)
expected = np.array(
[1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
)
children = m.children()
self.assertTrue(isinstance(children, dict))
self.assertEqual(len(children), 1)
self.assertTrue(isinstance(children["layers"], list))
self.assertEqual(len(children["layers"]), 4)
self.assertEqual(children["layers"][3], {})
flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
self.assertEqual(len(flat_children), 3)
out = nn.GELU()(mx.array(inputs))
self.assertTrue(np.allclose(out, expected))
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
self.assertEqual(len(leaves), 4)
self.assertEqual(leaves[0][0], "layers.0.layers.0")
self.assertEqual(leaves[1][0], "layers.1.layers.0")
self.assertEqual(leaves[2][0], "layers.1.layers.1")
self.assertEqual(leaves[3][0], "layers.2")
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
self.assertTrue(leaves[3][1] is m.layers[2])
m.eval()
def assert_not_training(k, m):
self.assertFalse(m.training)
m.apply_to_modules(assert_not_training)
m.train()
def assert_training(k, m):
self.assertTrue(m.training)
m.apply_to_modules(assert_training)
# Crudely check the approximations
x = mx.arange(-6.0, 6.0, 12 / 100)
y = nn.gelu(x)
y_hat1 = nn.gelu_approx(x)
y_hat2 = nn.gelu_fast_approx(x)
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
def test_sin_pe(self):
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
@ -623,21 +439,6 @@ class TestNN(mlx_tests.MLXTestCase):
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
)
def test_io(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
m = make_model()
tdir = tempfile.TemporaryDirectory()
file = os.path.join(tdir.name, "model.npz")
m.save_weights(file)
m_load = make_model()
m_load.load_weights(file)
tdir.cleanup()
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
self.assertTrue(all(tree_flatten(eq_tree)))
def test_relu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.relu(x)
@ -787,24 +588,6 @@ class TestNN(mlx_tests.MLXTestCase):
y = alibi(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
def test_hinge_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 1.0)
def test_huber_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 0.5)
def test_log_cosh_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
def test_dropout(self):
x = mx.ones((2, 4))
y = nn.Dropout(0.5)(x)