update module to check weights on load, also fix docs and reorganize tests

Awni Hannun 2024-01-01 11:43:57 -08:00
parent 99c80a2c8b
commit d2a826b3a4
5 changed files with 513 additions and 344 deletions


@@ -0,0 +1,33 @@
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. add toctree option to make autodoc generate the pages

.. autoclass:: {{ objname }}

   {% block attributes %}
   {% if attributes %}
   .. rubric:: Attributes

   .. autosummary::
      :toctree: .
   {% for item in attributes %}
      ~{{ fullname }}.{{ item }}
   {%- endfor %}
   {% endif %}
   {% endblock %}

   {% block methods %}
   {% if methods %}
   .. rubric:: Methods

   .. autosummary::
      :toctree: .
   {% for item in methods %}
   {%- if item not in inherited_members and item != '__init__' %}
      ~{{ fullname }}.{{ item }}
   {%- endif -%}
   {%- endfor %}
   {% endif %}
   {% endblock %}


@@ -170,10 +170,14 @@ In detail:
 :meth:`mlx.core.value_and_grad`

 .. autosummary::
-   :recursive:
    :toctree: _autosummary

    value_and_grad
+
+.. autosummary::
+   :toctree: _autosummary
+   :template: module-base-class.rst
+
    Module

 .. toctree::


@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.

 import textwrap
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import mlx.core as mx
 from mlx.utils import tree_flatten, tree_unflatten
@@ -87,11 +87,77 @@ class Module(dict):
     def __setattr__(self, key: str, val: Any):
         self[key] = val

-    def load_weights(self, file: str):
+    def load_weights(
+        self, file_or_weights: Union[str, List[Tuple[str, mx.array]]], strict=True
+    ):
         """
-        Load and update the model's weights from a `.npz` file.
+        Update the model's weights from a ``.npz`` file or list.
+
+        Args:
+            file_or_weights (str or list(tuple(str, mx.array))): The path to
+                the weights ``.npz`` file or a list of pairs of parameter names
+                and arrays.
+            strict (bool, optional): If ``True`` then checks that the provided
+                weights exactly match the parameters of the model. Otherwise,
+                only the weights actually contained in the model are loaded and
+                shapes are not checked. Default: ``True``.
+
+        Example:
+
+            .. code-block:: python
+
+                import mlx.core as mx
+                import mlx.nn as nn
+                model = nn.Linear(10, 10)
+
+                # Load from file
+                model.load_weights("weights.npz")
+
+                # Load from list
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                    ("bias", mx.zeros((10,))),
+                ]
+                model.load_weights(weights)
+
+                # Missing weight
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                ]
+                # Raises a ValueError exception
+                model.load_weights(weights)
+
+                # Ok, only updates the weight but not the bias
+                model.load_weights(weights, strict=False)
         """
-        self.update(tree_unflatten(list(mx.load(file).items())))
+        weights = file_or_weights
+
+        if isinstance(weights, str):
+            weights = list(mx.load(weights).items())
+
+        if strict:
+            new_weights = dict(weights)
+            curr_weights = dict(tree_flatten(self.parameters()))
+            if extras := (new_weights.keys() - curr_weights.keys()):
+                extras = " ".join(extras)
+                raise ValueError(f"Received parameters not in model: {extras}.")
+            if missing := (curr_weights.keys() - new_weights.keys()):
+                missing = " ".join(missing)
+                raise ValueError(f"Missing parameters: {missing}.")
+            for k, v in curr_weights.items():
+                v_new = new_weights[k]
+                if not isinstance(v_new, mx.array):
+                    raise ValueError(
+                        "Expected mx.array but received "
+                        f"{type(v_new)} for parameter {k}"
+                    )
+                if v_new.shape != v.shape:
+                    raise ValueError(
+                        f"Expected shape {v.shape} but received "
+                        f"shape {v_new.shape} for parameter {k}"
+                    )
+
+        self.update(tree_unflatten(weights))

     def save_weights(self, file: str):
         """

python/tests/test_losses.py (new file, 279 lines)

@@ -0,0 +1,279 @@
# Copyright © 2023 Apple Inc.

import unittest

import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np


class TestLosses(mlx_tests.MLXTestCase):
    def test_cross_entropy(self):
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])

        # Test with reduction 'none'
        losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
        expected_none = mx.array([0.0, 0.0])
        self.assertTrue(mx.array_equal(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertEqual(losses_mean, expected_mean)

        # Test with reduction 'sum'
        losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertEqual(losses_sum, expected_sum)

        # Test cases with weights and no label smoothing
        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
        targets = mx.array([0, 1])
        weights = mx.array([1.0, 2.0])

        # Reduction 'none'
        losses_none = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="none",
        )
        expected_none = mx.array([0.04858735, 0.0971747])  # Calculated losses
        self.assertTrue(
            np.allclose(losses_none, expected_none, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
        )

        # Reduction 'mean'
        losses_mean = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="mean",
        )
        expected_mean = mx.mean(expected_none)
        self.assertTrue(
            np.allclose(losses_mean, expected_mean, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
        )

        # Reduction 'sum'
        losses_sum = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="sum",
        )
        expected_sum = mx.sum(expected_none)
        self.assertTrue(
            np.allclose(losses_sum, expected_sum, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
        )

        # Test case with equal weights and label smoothing > 0
        logits = mx.array(
            [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
        )
        target = mx.array([2, 1, 0])

        losses_none = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="none"
        )
        expected_none = mx.array([1.29693, 1.38617, 1.48176])
        self.assertTrue(
            mx.allclose(expected_none, losses_none),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
        )

        expected_mean = mx.mean(expected_none)
        losses_mean = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="mean"
        )
        self.assertTrue(
            mx.allclose(losses_mean, expected_mean),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
        )

        expected_sum = mx.sum(expected_none)
        losses_sum = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="sum"
        )
        self.assertTrue(
            mx.allclose(losses_sum, expected_sum),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
        )

    def test_l1_loss(self):
        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
        targets = mx.array([0.5, 0.2, 0.9, 0.0])

        # Expected result
        expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
        expected_sum = mx.sum(expected_none)
        expected_mean = mx.mean(expected_none)

        losses = nn.losses.l1_loss(predictions, targets, reduction="none")
        self.assertTrue(
            mx.array_equal(losses, expected_none),
            "Test failed for l1_loss --reduction='none'",
        )

        losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
        self.assertTrue(mx.array_equal(losses, expected_sum))

        losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
        self.assertTrue(mx.array_equal(losses, expected_mean))

    def test_mse_loss(self):
        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
        targets = mx.array([0.7, 0.1, 0.8, 0.2])

        expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
        expected_mean = mx.mean(expected_none)
        expected_sum = mx.sum(expected_none)

        # Test with reduction 'none'
        losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
        self.assertTrue(
            np.allclose(losses_none, expected_none, 1e-5),
            "Test case failed for mse_loss --reduction='none'",
        )

        # Test with reduction 'mean'
        losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
        self.assertEqual(
            losses_mean,
            expected_mean,
            "Test case failed for mse_loss --reduction='mean'",
        )

        # Test with reduction 'sum'
        losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
        self.assertEqual(
            losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'"
        )

    def test_smooth_l1_loss(self):
        predictions = mx.array([1.5, 2.5, 0.5, 3.5])
        targets = mx.array([1.0, 2.0, 0.5, 2.5])
        beta = 1.0

        # Expected results
        expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
        expected_sum = mx.sum(expected_none)
        expected_mean = mx.mean(expected_none)

        # Test with reduction 'none'
        loss_none = nn.losses.smooth_l1_loss(
            predictions, targets, beta, reduction="none"
        )
        self.assertTrue(
            mx.array_equal(loss_none, expected_none),
            "Test case failed for smooth_l1_loss --reduction='none'",
        )

        # Test with reduction 'sum'
        loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
        self.assertEqual(
            loss_sum,
            expected_sum,
            "Test case failed for smooth_l1_loss --reduction='sum'",
        )

        # Test with reduction 'mean'
        loss_mean = nn.losses.smooth_l1_loss(
            predictions, targets, beta, reduction="mean"
        )
        self.assertEqual(
            loss_mean,
            expected_mean,
            "Test case failed for smooth_l1_loss --reduction='mean'",
        )

    def test_nll_loss(self):
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])

        # Test with reduction 'none'
        losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
        expected_none = mx.array([0.0, 0.0])
        self.assertTrue(mx.array_equal(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertEqual(losses_mean, expected_mean)

        # Test with reduction 'sum'
        losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertEqual(losses_sum, expected_sum)

    def test_kl_div_loss(self):
        p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
        q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))

        # Test with reduction 'none'
        losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
        expected_none = mx.array([0.0, 0.831777])
        self.assertTrue(mx.allclose(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertTrue(mx.allclose(losses_mean, expected_mean))

        # Test with reduction 'sum'
        losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertTrue(mx.allclose(losses_sum, expected_sum))

    def test_triplet_loss(self):
        anchors = mx.array([[1, 2, 3], [1, 2, 3]])
        positives = mx.array([[4, 5, 6], [0, -1, 2]])
        negatives = mx.array([[7, 8, 9], [3, 2, 3]])

        # Test with reduction 'none'
        losses_none = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="none"
        )
        expected_none = mx.array([0, 2.31662])
        self.assertTrue(mx.allclose(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="mean"
        )
        expected_mean = mx.mean(expected_none)
        self.assertTrue(mx.allclose(losses_mean, expected_mean))

        # Test with reduction 'sum'
        losses_sum = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="sum"
        )
        expected_sum = mx.sum(expected_none)
        self.assertTrue(mx.allclose(losses_sum, expected_sum))

    def test_hinge_loss(self):
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
        self.assertEqual(loss, 1.0)

    def test_huber_loss(self):
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
        self.assertEqual(loss, 0.5)

    def test_log_cosh_loss(self):
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
        self.assertAlmostEqual(loss.item(), 0.433781, places=6)


if __name__ == "__main__":
    unittest.main()


@@ -11,7 +11,116 @@ import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten


-class TestNN(mlx_tests.MLXTestCase):
+class TestBase(mlx_tests.MLXTestCase):
+    def test_update(self):
+        pass
+
+    def test_module_utilities(self):
+        m = nn.Sequential(
+            nn.Sequential(nn.Linear(2, 10), nn.relu),
+            nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
+            nn.Linear(10, 1),
+            mx.sigmoid,
+        )
+
+        children = m.children()
+        self.assertTrue(isinstance(children, dict))
+        self.assertEqual(len(children), 1)
+        self.assertTrue(isinstance(children["layers"], list))
+        self.assertEqual(len(children["layers"]), 4)
+        self.assertEqual(children["layers"][3], {})
+        flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
+        self.assertEqual(len(flat_children), 3)
+
+        leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
+        self.assertEqual(len(leaves), 4)
+        self.assertEqual(leaves[0][0], "layers.0.layers.0")
+        self.assertEqual(leaves[1][0], "layers.1.layers.0")
+        self.assertEqual(leaves[2][0], "layers.1.layers.1")
+        self.assertEqual(leaves[3][0], "layers.2")
+        self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
+        self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
+        self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
+        self.assertTrue(leaves[3][1] is m.layers[2])
+
+        m.eval()
+
+        def assert_not_training(k, m):
+            self.assertFalse(m.training)
+
+        m.apply_to_modules(assert_not_training)
+
+        m.train()
+
+        def assert_training(k, m):
+            self.assertTrue(m.training)
+
+        m.apply_to_modules(assert_training)
+
+    def test_io(self):
+        def make_model():
+            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
+
+        m = make_model()
+        tdir = tempfile.TemporaryDirectory()
+        file = os.path.join(tdir.name, "model.npz")
+        m.save_weights(file)
+        m_load = make_model()
+        m_load.load_weights(file)
+        tdir.cleanup()
+
+        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
+        self.assertTrue(all(tree_flatten(eq_tree)))
+
+    def test_from_weights(self):
+        m = nn.Linear(2, 2)
+
+        # Too few weights
+        weights = [("weight", mx.ones((2, 2)))]
+        with self.assertRaises(ValueError):
+            m.load_weights(weights)
+
+        m.load_weights(weights, strict=False)
+        self.assertTrue(mx.array_equal(m.weight, weights[0][1]))
+
+        # Wrong name
+        with self.assertRaises(ValueError):
+            m.load_weights([("weihgt", mx.ones((2, 2)))])
+
+        # Ok
+        m.load_weights([("weihgt", mx.ones((2, 2)))], strict=False)
+
+        # Too many weights
+        with self.assertRaises(ValueError):
+            m.load_weights(
+                [
+                    ("weight", mx.ones((2, 2))),
+                    ("bias", mx.ones((2,))),
+                    ("bias2", mx.ones((2,))),
+                ]
+            )
+
+        # Wrong shape
+        with self.assertRaises(ValueError):
+            m.load_weights(
+                [
+                    ("weight", mx.ones((2, 2))),
+                    ("bias", mx.ones((2, 1))),
+                ]
+            )
+
+        # Wrong type
+        with self.assertRaises(ValueError):
+            m.load_weights(
+                [
+                    ("weight", mx.ones((2, 2))),
+                    ("bias", 3),
+                ]
+            )
+
+
+class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
         inputs = mx.zeros((10, 4))
         layer = nn.Identity()
@@ -31,272 +140,6 @@ class TestNN(mlx_tests.MLXTestCase):
         outputs = layer(inputs1, inputs2)
         self.assertEqual(tuple(outputs.shape), (10, 6))

-    def test_cross_entropy(self):
-        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
-        targets = mx.array([0, 1])
-
-        # Test with reduction 'none'
-        losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
-        expected_none = mx.array([0.0, 0.0])
-        self.assertTrue(mx.array_equal(losses_none, expected_none))
-
-        # Test with reduction 'mean'
-        losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
-        expected_mean = mx.mean(expected_none)
-        self.assertEqual(losses_mean, expected_mean)
-
-        # Test with reduction 'sum'
-        losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
-        expected_sum = mx.sum(expected_none)
-        self.assertEqual(losses_sum, expected_sum)
-
-        # Test cases with weights and no label smoothing
-        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
-        targets = mx.array([0, 1])
-        weights = mx.array([1.0, 2.0])
-
-        # Reduction 'none'
-        losses_none = nn.losses.cross_entropy(
-            logits,
-            targets,
-            weights=weights,
-            reduction="none",
-        )
-        expected_none = mx.array([0.04858735, 0.0971747])  # Calculated losses
-        self.assertTrue(
-            np.allclose(losses_none, expected_none, atol=1e-5),
-            "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
-        )
-
-        # Reduction 'mean'
-        losses_mean = nn.losses.cross_entropy(
-            logits,
-            targets,
-            weights=weights,
-            reduction="mean",
-        )
-        expected_mean = mx.mean(expected_none)
-        self.assertTrue(
-            np.allclose(losses_mean, expected_mean, atol=1e-5),
-            "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
-        )
-
-        # Reduction 'sum'
-        losses_sum = nn.losses.cross_entropy(
-            logits,
-            targets,
-            weights=weights,
-            reduction="sum",
-        )
-        expected_sum = mx.sum(expected_none)
-        self.assertTrue(
-            np.allclose(losses_sum, expected_sum, atol=1e-5),
-            "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
-        )
-
-        # Test case with equal weights and label smoothing > 0
-        logits = mx.array(
-            [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
-        )
-        target = mx.array([2, 1, 0])
-
-        losses_none = nn.losses.cross_entropy(
-            logits, target, label_smoothing=0.3, reduction="none"
-        )
-        expected_none = mx.array([1.29693, 1.38617, 1.48176])
-        self.assertTrue(
-            mx.allclose(expected_none, losses_none),
-            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
-        )
-
-        expected_mean = mx.mean(expected_none)
-        losses_mean = nn.losses.cross_entropy(
-            logits, target, label_smoothing=0.3, reduction="mean"
-        )
-        self.assertTrue(
-            mx.allclose(losses_mean, expected_mean),
-            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
-        )
-
-        expected_sum = mx.sum(expected_none)
-        losses_sum = nn.losses.cross_entropy(
-            logits, target, label_smoothing=0.3, reduction="sum"
-        )
-        self.assertTrue(
-            mx.allclose(losses_sum, expected_sum),
-            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
-        )
-
-    def test_l1_loss(self):
-        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
-        targets = mx.array([0.5, 0.2, 0.9, 0.0])
-
-        # Expected result
-        expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
-        expected_sum = mx.sum(expected_none)
-        expected_mean = mx.mean(expected_none)
-
-        losses = nn.losses.l1_loss(predictions, targets, reduction="none")
-        self.assertTrue(
-            mx.array_equal(losses, expected_none),
-            "Test failed for l1_loss --reduction='none'",
-        )
-
-        losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
-        self.assertTrue(mx.array_equal(losses, expected_sum))
-
-        losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
-        self.assertTrue(mx.array_equal(losses, expected_mean))
-
-    def test_mse_loss(self):
-        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
-        targets = mx.array([0.7, 0.1, 0.8, 0.2])
-
-        expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
-        expected_mean = mx.mean(expected_none)
-        expected_sum = mx.sum(expected_none)
-
-        # Test with reduction 'none'
-        losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
-        self.assertTrue(
-            np.allclose(losses_none, expected_none, 1e-5),
-            "Test case failed for mse_loss --reduction='none'",
-        )
-
-        # Test with reduction 'mean'
-        losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
-        self.assertEqual(
-            losses_mean,
-            expected_mean,
-            "Test case failed for mse_loss --reduction='mean'",
-        )
-
-        # Test with reduction 'sum'
-        losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
-        self.assertEqual(
-            losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'"
-        )
-
-    def test_smooth_l1_loss(self):
-        predictions = mx.array([1.5, 2.5, 0.5, 3.5])
-        targets = mx.array([1.0, 2.0, 0.5, 2.5])
-        beta = 1.0
-
-        # Expected results
-        expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
-        expected_sum = mx.sum(expected_none)
-        expected_mean = mx.mean(expected_none)
-
-        # Test with reduction 'none'
-        loss_none = nn.losses.smooth_l1_loss(
-            predictions, targets, beta, reduction="none"
-        )
-        self.assertTrue(
-            mx.array_equal(loss_none, expected_none),
-            "Test case failed for smooth_l1_loss --reduction='none'",
-        )
-
-        # Test with reduction 'sum'
-        loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
-        self.assertEqual(
-            loss_sum,
-            expected_sum,
-            "Test case failed for smooth_l1_loss --reduction='sum'",
-        )
-
-        # Test with reduction 'mean'
-        loss_mean = nn.losses.smooth_l1_loss(
-            predictions, targets, beta, reduction="mean"
-        )
-        self.assertEqual(
-            loss_mean,
-            expected_mean,
-            "Test case failed for smooth_l1_loss --reduction='mean'",
-        )
-
-    def test_nll_loss(self):
-        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
-        targets = mx.array([0, 1])
-
-        # Test with reduction 'none'
-        losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
-        expected_none = mx.array([0.0, 0.0])
-        self.assertTrue(mx.array_equal(losses_none, expected_none))
-
-        # Test with reduction 'mean'
-        losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
-        expected_mean = mx.mean(expected_none)
-        self.assertEqual(losses_mean, expected_mean)
-
-        # Test with reduction 'sum'
-        losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
-        expected_sum = mx.sum(expected_none)
-        self.assertEqual(losses_sum, expected_sum)
-
-    def test_kl_div_loss(self):
-        p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
-        q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))
-
-        # Test with reduction 'none'
-        losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
-        expected_none = mx.array([0.0, 0.831777])
-        self.assertTrue(mx.allclose(losses_none, expected_none))
-
-        # Test with reduction 'mean'
-        losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
-        expected_mean = mx.mean(expected_none)
-        self.assertTrue(mx.allclose(losses_mean, expected_mean))
-
-        # Test with reduction 'sum'
-        losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
-        expected_sum = mx.sum(expected_none)
-        self.assertTrue(mx.allclose(losses_sum, expected_sum))
-
-    def test_triplet_loss(self):
-        anchors = mx.array([[1, 2, 3], [1, 2, 3]])
-        positives = mx.array([[4, 5, 6], [0, -1, 2]])
-        negatives = mx.array([[7, 8, 9], [3, 2, 3]])
-
-        # Test with reduction 'none'
-        losses_none = nn.losses.triplet_loss(
-            anchors, positives, negatives, reduction="none"
-        )
-        expected_none = mx.array([0, 2.31662])
-        self.assertTrue(mx.allclose(losses_none, expected_none))
-
-        # Test with reduction 'mean'
-        losses_mean = nn.losses.triplet_loss(
-            anchors, positives, negatives, reduction="mean"
-        )
-        expected_mean = mx.mean(expected_none)
-        self.assertTrue(mx.allclose(losses_mean, expected_mean))
-
-        # Test with reduction 'sum'
-        losses_sum = nn.losses.triplet_loss(
-            anchors, positives, negatives, reduction="sum"
-        )
-        expected_sum = mx.sum(expected_none)
-        self.assertTrue(mx.allclose(losses_sum, expected_sum))
-
-    def test_gelu(self):
-        inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
-
-        # From: jax.nn.gelu(np.array(inputs), approximate=False)
-        expected = np.array(
-            [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
-        )
-
-        out = nn.GELU()(mx.array(inputs))
-        self.assertTrue(np.allclose(out, expected))
-
-        # Crudely check the approximations
-        x = mx.arange(-6.0, 6.0, 12 / 100)
-        y = nn.gelu(x)
-        y_hat1 = nn.gelu_approx(x)
-        y_hat2 = nn.gelu_fast_approx(x)
-        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
-        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
-
     def test_group_norm(self):
         x = mx.arange(100, dtype=mx.float32)
         x = x.reshape(1, 10, 10, 1)
@@ -570,47 +413,24 @@ class TestNN(mlx_tests.MLXTestCase):
         y2 = m(x)
         self.assertTrue(mx.array_equal(y, y2))

-    def test_module_utilities(self):
-        m = nn.Sequential(
-            nn.Sequential(nn.Linear(2, 10), nn.relu),
-            nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
-            nn.Linear(10, 1),
-            mx.sigmoid,
-        )
-
-        children = m.children()
-        self.assertTrue(isinstance(children, dict))
-        self.assertEqual(len(children), 1)
-        self.assertTrue(isinstance(children["layers"], list))
-        self.assertEqual(len(children["layers"]), 4)
-        self.assertEqual(children["layers"][3], {})
-        flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
-        self.assertEqual(len(flat_children), 3)
-
-        leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
-        self.assertEqual(len(leaves), 4)
-        self.assertEqual(leaves[0][0], "layers.0.layers.0")
-        self.assertEqual(leaves[1][0], "layers.1.layers.0")
-        self.assertEqual(leaves[2][0], "layers.1.layers.1")
-        self.assertEqual(leaves[3][0], "layers.2")
-        self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
-        self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
-        self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
-        self.assertTrue(leaves[3][1] is m.layers[2])
-
-        m.eval()
-
-        def assert_not_training(k, m):
-            self.assertFalse(m.training)
-
-        m.apply_to_modules(assert_not_training)
-
-        m.train()
-
-        def assert_training(k, m):
-            self.assertTrue(m.training)
-
-        m.apply_to_modules(assert_training)
+    def test_gelu(self):
+        inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
+
+        # From: jax.nn.gelu(np.array(inputs), approximate=False)
+        expected = np.array(
+            [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
+        )
+
+        out = nn.GELU()(mx.array(inputs))
+        self.assertTrue(np.allclose(out, expected))
+
+        # Crudely check the approximations
+        x = mx.arange(-6.0, 6.0, 12 / 100)
+        y = nn.gelu(x)
+        y_hat1 = nn.gelu_approx(x)
+        y_hat2 = nn.gelu_fast_approx(x)
+        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
+        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)

     def test_sin_pe(self):
         m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
@@ -623,21 +443,6 @@ class TestNN(mlx_tests.MLXTestCase):
             mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
         )

-    def test_io(self):
-        def make_model():
-            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
-
-        m = make_model()
-        tdir = tempfile.TemporaryDirectory()
-        file = os.path.join(tdir.name, "model.npz")
-        m.save_weights(file)
-        m_load = make_model()
-        m_load.load_weights(file)
-        tdir.cleanup()
-
-        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
-        self.assertTrue(all(tree_flatten(eq_tree)))
-
     def test_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.relu(x)
@@ -787,24 +592,6 @@ class TestNN(mlx_tests.MLXTestCase):
         y = alibi(x.astype(mx.float16))
         self.assertTrue(y.dtype, mx.float16)

-    def test_hinge_loss(self):
-        inputs = mx.ones((2, 4))
-        targets = mx.zeros((2, 4))
-        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
-        self.assertEqual(loss, 1.0)
-
-    def test_huber_loss(self):
-        inputs = mx.ones((2, 4))
-        targets = mx.zeros((2, 4))
-        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
-        self.assertEqual(loss, 0.5)
-
-    def test_log_cosh_loss(self):
-        inputs = mx.ones((2, 4))
-        targets = mx.zeros((2, 4))
-        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
-        self.assertAlmostEqual(loss.item(), 0.433781, places=6)
-
     def test_dropout(self):
         x = mx.ones((2, 4))
         y = nn.Dropout(0.5)(x)