# Mirror of https://github.com/ml-explore/mlx.git
# (synced 2025-08-21 20:46:46 +08:00; 933 lines, 34 KiB, Python)
# Copyright © 2023 Apple Inc.
|
|
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import mlx_tests
|
|
import numpy as np
|
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
|
|
|
|
|
class TestNN(mlx_tests.MLXTestCase):
    """Unit tests for mlx.nn layers, activations, losses, and Module utilities."""

    def test_linear(self):
        """Linear maps a (10, 4) input to a (10, 8) output."""
        inputs = mx.zeros((10, 4))
        layer = nn.Linear(input_dims=4, output_dims=8)
        outputs = layer(inputs)
        self.assertEqual(tuple(outputs.shape), (10, 8))

    def test_cross_entropy(self):
        """cross_entropy for each reduction, with weights, and with label smoothing."""
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])

        # Test with reduction 'none'
        losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
        expected_none = mx.array([0.0, 0.0])
        self.assertTrue(mx.array_equal(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertEqual(losses_mean, expected_mean)

        # Test with reduction 'sum'
        losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertEqual(losses_sum, expected_sum)

        # Test cases with weights and no label smoothing
        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
        targets = mx.array([0, 1])
        weights = mx.array([1.0, 2.0])

        # Reduction 'none'
        losses_none = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="none",
        )
        # Each unweighted loss is log(1 + exp(-3)) ~= 0.04858735; the second
        # sample is scaled by its weight of 2.
        expected_none = mx.array([0.04858735, 0.0971747])
        self.assertTrue(
            np.allclose(losses_none, expected_none, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
        )

        # Reduction 'mean'
        losses_mean = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="mean",
        )
        expected_mean = mx.mean(expected_none)
        self.assertTrue(
            np.allclose(losses_mean, expected_mean, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
        )

        # Reduction 'sum'
        losses_sum = nn.losses.cross_entropy(
            logits,
            targets,
            weights=weights,
            reduction="sum",
        )
        expected_sum = mx.sum(expected_none)
        self.assertTrue(
            np.allclose(losses_sum, expected_sum, atol=1e-5),
            "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
        )

        # Test case with label smoothing > 0 and no per-sample weights
        logits = mx.array(
            [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
        )
        target = mx.array([2, 1, 0])

        losses_none = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="none"
        )
        expected_none = mx.array([1.29693, 1.38617, 1.48176])
        self.assertTrue(
            mx.allclose(expected_none, losses_none),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
        )

        expected_mean = mx.mean(expected_none)
        losses_mean = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="mean"
        )
        self.assertTrue(
            mx.allclose(losses_mean, expected_mean),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
        )

        expected_sum = mx.sum(expected_none)
        losses_sum = nn.losses.cross_entropy(
            logits, target, label_smoothing=0.3, reduction="sum"
        )
        self.assertTrue(
            mx.allclose(losses_sum, expected_sum),
            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
        )

    def test_l1_loss(self):
        """l1_loss is zero everywhere when predictions equal targets."""
        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
        targets = mx.array([0.5, 0.2, 0.9, 0.0])

        # Expected result
        expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
        expected_sum = mx.sum(expected_none)
        expected_mean = mx.mean(expected_none)

        losses = nn.losses.l1_loss(predictions, targets, reduction="none")
        self.assertTrue(
            mx.array_equal(losses, expected_none),
            "Test failed for l1_loss --reduction='none'",
        )

        losses = nn.losses.l1_loss(predictions, targets, reduction="sum")
        self.assertTrue(mx.array_equal(losses, expected_sum))

        losses = nn.losses.l1_loss(predictions, targets, reduction="mean")
        self.assertTrue(mx.array_equal(losses, expected_mean))

    def test_mse_loss(self):
        """mse_loss under each reduction mode."""
        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
        targets = mx.array([0.7, 0.1, 0.8, 0.2])

        expected_none = mx.array([0.04, 0.01, 0.01, 0.04])
        expected_mean = mx.mean(expected_none)
        expected_sum = mx.sum(expected_none)

        # Test with reduction 'none'
        losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
        self.assertTrue(
            np.allclose(losses_none, expected_none, rtol=1e-5),
            "Test case failed for mse_loss --reduction='none'",
        )

        # Test with reduction 'mean'. The squared differences are not exactly
        # representable in floating point, so compare with a tolerance rather
        # than exact equality.
        losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
        self.assertTrue(
            mx.allclose(losses_mean, expected_mean),
            "Test case failed for mse_loss --reduction='mean'",
        )

        # Test with reduction 'sum'
        losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
        self.assertTrue(
            mx.allclose(losses_sum, expected_sum),
            "Test case failed for mse_loss --reduction='sum'",
        )

    def test_smooth_l1_loss(self):
        """smooth_l1_loss with beta=1 under each reduction mode."""
        predictions = mx.array([1.5, 2.5, 0.5, 3.5])
        targets = mx.array([1.0, 2.0, 0.5, 2.5])
        beta = 1.0

        # Expected results
        expected_none = mx.array([0.125, 0.125, 0.0, 0.5])
        expected_sum = mx.sum(expected_none)
        expected_mean = mx.mean(expected_none)

        # Test with reduction 'none'
        loss_none = nn.losses.smooth_l1_loss(
            predictions, targets, beta, reduction="none"
        )
        self.assertTrue(
            mx.array_equal(loss_none, expected_none),
            "Test case failed for smooth_l1_loss --reduction='none'",
        )

        # Test with reduction 'sum'
        loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum")
        self.assertEqual(
            loss_sum,
            expected_sum,
            "Test case failed for smooth_l1_loss --reduction='sum'",
        )

        # Test with reduction 'mean'
        loss_mean = nn.losses.smooth_l1_loss(
            predictions, targets, beta, reduction="mean"
        )
        self.assertEqual(
            loss_mean,
            expected_mean,
            "Test case failed for smooth_l1_loss --reduction='mean'",
        )

    def test_nll_loss(self):
        """nll_loss is zero when all log-probability mass sits on the target."""
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])

        # Test with reduction 'none'
        losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
        expected_none = mx.array([0.0, 0.0])
        self.assertTrue(mx.array_equal(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertEqual(losses_mean, expected_mean)

        # Test with reduction 'sum'
        losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertEqual(losses_sum, expected_sum)

    def test_kl_div_loss(self):
        """kl_div_loss on log-probabilities under each reduction mode."""
        p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
        q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))

        # Test with reduction 'none'
        losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
        expected_none = mx.array([0.0, 0.831777])
        self.assertTrue(mx.allclose(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
        expected_mean = mx.mean(expected_none)
        self.assertTrue(mx.allclose(losses_mean, expected_mean))

        # Test with reduction 'sum'
        losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
        expected_sum = mx.sum(expected_none)
        self.assertTrue(mx.allclose(losses_sum, expected_sum))

    def test_triplet_loss(self):
        """triplet_loss under each reduction mode."""
        anchors = mx.array([[1, 2, 3], [1, 2, 3]])
        positives = mx.array([[4, 5, 6], [0, -1, 2]])
        negatives = mx.array([[7, 8, 9], [3, 2, 3]])

        # Test with reduction 'none'
        losses_none = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="none"
        )
        expected_none = mx.array([0, 2.31662])
        self.assertTrue(mx.allclose(losses_none, expected_none))

        # Test with reduction 'mean'
        losses_mean = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="mean"
        )
        expected_mean = mx.mean(expected_none)
        self.assertTrue(mx.allclose(losses_mean, expected_mean))

        # Test with reduction 'sum'
        losses_sum = nn.losses.triplet_loss(
            anchors, positives, negatives, reduction="sum"
        )
        expected_sum = mx.sum(expected_none)
        self.assertTrue(mx.allclose(losses_sum, expected_sum))

    def test_gelu(self):
        """Exact GELU matches reference values; approximations stay close."""
        inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]

        # From: jax.nn.gelu(np.array(inputs), approximate=False)
        expected = np.array(
            [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
        )

        out = nn.GELU()(mx.array(inputs))
        self.assertTrue(np.allclose(out, expected))

        # Crudely check the approximations
        x = mx.arange(-6.0, 6.0, 12 / 100)
        y = nn.gelu(x)
        y_hat1 = nn.gelu_approx(x)
        y_hat2 = nn.gelu_fast_approx(x)
        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)

    def test_group_norm(self):
        """GroupNorm normalizes per group in both group layouts, then applies affine."""
        x = mx.arange(100, dtype=mx.float32)
        x = x.reshape(1, 10, 10, 1)
        x = mx.broadcast_to(x, (2, 10, 10, 4))
        x = mx.concatenate([x, 0.5 * x], axis=-1)

        # Group norm in groups last mode
        g = nn.GroupNorm(2, 8)
        y = g(x)
        means = y.reshape(2, -1, 2).mean(axis=1)
        var = y.reshape(2, -1, 2).var(axis=1)
        self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
        g.weight = g.weight * 2
        g.bias = g.bias + 3
        y = g(x)
        means = y.reshape(2, -1, 2).mean(axis=1)
        var = y.reshape(2, -1, 2).var(axis=1)
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

        # Group norm in groups first mode
        g = nn.GroupNorm(2, 8, pytorch_compatible=True)
        y = g(x)
        means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
        var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
        self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
        g.weight = g.weight * 2
        g.bias = g.bias + 3
        y = g(x)
        means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
        var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

    def test_instance_norm(self):
        """InstanceNorm matches reference outputs for 1d, 2d and 3d inputs."""
        # Test InstanceNorm1d
        x = mx.array(
            [
                [
                    [-0.0119524, -0.500331, 1.12958, 1.39955],
                    [1.1263, 0.517899, -0.21413, 0.891329],
                    [2.02223, -1.21143, -2.48738, 1.63289],
                ],
                [
                    [0.241417, -1.42512, 2.739, -1.23175],
                    [-0.619157, 0.970817, -1.2506, 0.32756],
                    [-0.77484, -1.31352, 1.56844, 1.13969],
                ],
            ]
        )
        inorm = nn.InstanceNorm(num_features=3)
        y = inorm(x)
        expected_y = [
            [
                [-0.657082, -1.27879, 0.796097, 1.13978],
                [1.07593, -0.123075, -1.56572, 0.61286],
                [1.0712, -0.632503, -1.30476, 0.866066],
            ],
            [
                [0.0964433, -0.904773, 1.59693, -0.788599],
                [-0.557908, 1.30444, -1.29751, 0.550987],
                [-0.759886, -1.20013, 1.15521, 0.804804],
            ],
        ]
        self.assertEqual(x.shape, y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

        # Test InstanceNorm2d
        x = mx.array(
            [
                [
                    [
                        [-0.458824, -0.447996, 0.0486988],
                        [1.13049, 0.301795, -2.23876],
                        [0.0986325, -1.25257, -0.329399],
                    ],
                    [
                        [0.483254, -0.176577, -0.0611224],
                        [0.345315, 0.99207, -0.758631],
                        [-1.82973, 0.154442, -0.319107],
                    ],
                    [
                        [-0.58611, -0.622545, 1.8845],
                        [-0.926389, -0.184927, -1.12639],
                        [-0.241765, -0.556204, 0.830584],
                    ],
                ],
                [
                    [
                        [1.04407, 0.0800776, 0.782321],
                        [0.671423, -0.110299, 0.159905],
                        [0.810252, 0.182597, -0.0621687],
                    ],
                    [
                        [0.073752, 1.2513, -0.444367],
                        [-1.21689, -1.42248, 0.516452],
                        [1.50456, 0.0576239, 0.184253],
                    ],
                    [
                        [0.407081, 1.20627, 0.563132],
                        [-1.88979, 1.17838, -0.539121],
                        [1.08659, 0.973883, 0.784216],
                    ],
                ],
            ]
        )
        inorm = nn.InstanceNorm(num_features=3)
        y = inorm(x)
        expected_y = [
            [
                [
                    [-0.120422, -0.108465, 0.440008],
                    [1.63457, 0.719488, -2.08591],
                    [0.495147, -0.996913, 0.0224944],
                ],
                [
                    [0.801504, -0.0608616, 0.0900314],
                    [0.621224, 1.4665, -0.821576],
                    [-2.22144, 0.371763, -0.247141],
                ],
                [
                    [-0.463984, -0.504602, 2.29032],
                    [-0.843336, -0.0167355, -1.0663],
                    [-0.0800997, -0.430644, 1.11538],
                ],
            ],
            [
                [
                    [1.59749, -0.776381, 0.95293],
                    [0.679838, -1.24519, -0.579803],
                    [1.02171, -0.523923, -1.12667],
                ],
                [
                    [0.0190289, 1.28291, -0.537076],
                    [-1.36624, -1.5869, 0.494185],
                    [1.55474, 0.00171834, 0.137631],
                ],
                [
                    [-0.012331, 0.817234, 0.149652],
                    [-2.39651, 0.78829, -0.994498],
                    [0.693007, 0.576016, 0.37914],
                ],
            ],
        ]
        self.assertEqual(x.shape, y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

        # Test InstanceNorm3d
        x = mx.array(
            [
                [
                    [
                        [[0.777621, -2.1722], [-1.41317, 0.284446]],
                        [[0.11, -0.837743], [-2.40205, 0.336682]],
                        [[0.789185, -1.42998], [-0.459489, 0.0298199]],
                    ],
                    [
                        [[0.528145, 0.128192], [0.476288, -0.649858]],
                        [[-0.12431, 1.93502], [-1.25873, -0.261986]],
                        [[-1.63747, -1.73247], [-2.15559, 0.10275]],
                    ],
                    [
                        [[-1.56133, 0.153862], [-1.20411, 0.152112]],
                        [[1.18768, 0.00236324], [-2.04243, 1.54289]],
                        [[0.67917, -0.402572], [-0.249959, -0.821897]],
                    ],
                ],
                [
                    [
                        [[-2.12354, 0.317797], [-0.146628, 0.0329215]],
                        [[-1.55784, 2.41031], [0.226341, 0.265387]],
                        [[0.990317, 0.475161], [-1.37804, -0.501041]],
                    ],
                    [
                        [[0.643973, -0.682916], [-0.987925, 1.54086]],
                        [[0.71179, -0.290786], [0.057712, -0.742304]],
                        [[-0.399875, -1.10479], [1.40097, 0.0723374]],
                    ],
                    [
                        [[0.72391, 0.016364], [0.573199, 0.213092]],
                        [[-0.0678402, 0.00449439], [-1.58342, 1.28133]],
                        [[-0.357647, -1.07389], [0.141618, -0.386141]],
                    ],
                ],
            ]
        )
        inorm = nn.InstanceNorm(num_features=3)
        y = inorm(x)
        expected_y = [
            [
                [
                    [[1.23593, -1.54739], [-0.831204, 0.770588]],
                    [[0.605988, -0.288258], [-1.76427, 0.819875]],
                    [[1.24684, -0.847068], [0.0686449, 0.530334]],
                ],
                [
                    [[0.821849, 0.462867], [0.775304, -0.23548]],
                    [[0.236231, 2.0846], [-0.78198, 0.112659]],
                    [[-1.12192, -1.20719], [-1.58697, 0.440032]],
                ],
                [
                    [[-1.30944, 0.357126], [-0.962338, 0.355425]],
                    [[1.36163, 0.209922], [-1.77689, 1.70677]],
                    [[0.867539, -0.183531], [-0.0352458, -0.590967]],
                ],
            ],
            [
                [
                    [[-1.75315, 0.343736], [-0.0551618, 0.0990544]],
                    [[-1.26726, 2.14101], [0.265184, 0.298721]],
                    [[0.921369, 0.478897], [-1.11283, -0.35957]],
                ],
                [
                    [[0.733967, -0.822472], [-1.18025, 1.78602]],
                    [[0.813517, -0.362504], [0.0462839, -0.892134]],
                    [[-0.490465, -1.31732], [1.62192, 0.0634394]],
                ],
                [
                    [[1.04349, 0.080661], [0.838402, 0.348368]],
                    [[-0.033924, 0.0645089], [-2.09632, 1.80203]],
                    [[-0.428293, -1.40296], [0.251107, -0.467067]],
                ],
            ],
        ]
        self.assertEqual(x.shape, y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

    def test_batch_norm(self):
        """BatchNorm outputs and running stats in train, eval, no-affine and 3D modes."""
        mx.random.seed(42)
        x = mx.random.normal((5, 4), dtype=mx.float32)

        # Batch norm
        bn = nn.BatchNorm(num_features=4, affine=True)
        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
        y = bn(x)
        expected_y = mx.array(
            [
                [-0.439520, 1.647328, -0.955515, 1.966031],
                [-1.726690, -1.449826, -0.234026, -0.723364],
                [0.938414, -0.349603, -0.354470, -0.175369],
                [0.305006, 0.234914, -0.393017, -0.459385],
                [0.922789, -0.082813, 1.937028, -0.607913],
            ],
        )
        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
        self.assertEqual(x.shape, y.shape)
        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))

        # test eval mode
        bn.eval()
        y = bn(x)
        expected_y = mx.array(
            [
                [-0.15984, 1.73159, -1.25456, 1.57891],
                [-0.872193, -1.4281, -0.414439, -0.228678],
                [0.602743, -0.30566, -0.554687, 0.139639],
                [0.252199, 0.29066, -0.599572, -0.0512532],
                [0.594096, -0.0334829, 2.11359, -0.151081],
            ]
        )

        self.assertEqual(x.shape, y.shape)
        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))

        # test_no_affine
        bn = nn.BatchNorm(num_features=4, affine=False)
        y = bn(x)
        expected_y = mx.array(
            [
                [-0.439520, 1.647328, -0.955515, 1.966031],
                [-1.726690, -1.449826, -0.234026, -0.723364],
                [0.938414, -0.349603, -0.354470, -0.175369],
                [0.305006, 0.234914, -0.393017, -0.459385],
                [0.922789, -0.082813, 1.937028, -0.607913],
            ]
        )
        self.assertEqual(x.shape, y.shape)
        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))

        # test with 3D input
        mx.random.seed(42)
        N = 2
        L = 4
        C = 5
        x = mx.random.normal((N, L, C), dtype=mx.float32)

        # Batch norm
        bn = nn.BatchNorm(num_features=C, affine=True)
        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
        y = bn(x)
        self.assertEqual(x.shape, y.shape)
        expected_y = mx.array(
            [
                [
                    [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
                    [1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
                    [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
                    [0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
                ],
                [
                    [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
                    [1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
                    [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
                    [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
                ],
            ]
        )
        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
        expected_mean = mx.array(
            [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
        )
        expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))

        # A 5D input is rejected.
        x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
        with self.assertRaises(ValueError):
            bn(x)

    def test_batch_norm_stats(self):
        """Running mean/var follow an exponential moving average with momentum 0.1."""
        batch_size = 2
        num_features = 4
        h = 3
        w = 3
        momentum = 0.1

        batch_norm = nn.BatchNorm(num_features)

        batch_norm.train()
        running_mean = np.array(batch_norm._running_mean)
        running_var = np.array(batch_norm._running_var)

        data = mx.random.normal((batch_size, num_features))

        # The forward pass is run only for its side effect on the running stats.
        batch_norm(data)
        np_data = np.array(data)
        means = np.mean(np_data, axis=0)
        variances = np.var(np_data, axis=0)
        running_mean = (1 - momentum) * running_mean + momentum * means
        running_var = (1 - momentum) * running_var + momentum * variances
        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))

        batch_norm = nn.BatchNorm(num_features)

        batch_norm.train()
        running_mean = np.array(batch_norm._running_mean)
        running_var = np.array(batch_norm._running_var)
        data = mx.random.normal((batch_size, h, w, num_features))

        batch_norm(data)
        np_data = np.array(data)
        means = np.mean(np_data, axis=(0, 1, 2))
        variances = np.var(np_data, axis=(0, 1, 2))
        running_mean = (1 - momentum) * running_mean + momentum * means
        running_var = (1 - momentum) * running_var + momentum * variances
        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))

    def test_conv1d(self):
        """Conv1d output shape and values, stride handling, and the bias flag."""
        N = 5
        L = 12
        ks = 3
        C_in = 2
        C_out = 4
        x = mx.ones((N, L, C_in))
        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
        c.weight = mx.ones_like(c.weight)
        y = c(x)
        self.assertEqual(y.shape, [N, L - ks + 1, C_out])
        self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))

        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
        y = c(x)
        # Unpadded output length for stride s is (L - ks) // s + 1.
        self.assertEqual(y.shape, [N, (L - ks) // 2 + 1, C_out])
        self.assertIn("bias", c.parameters())

        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
        self.assertNotIn("bias", c.parameters())

    def test_conv2d(self):
        """Conv2d shapes and values for several padding/stride settings on ones input."""
        x = mx.ones((4, 8, 8, 3))
        c = nn.Conv2d(3, 1, 8)
        y = c(x)
        self.assertEqual(y.shape, [4, 1, 1, 1])
        c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
        y = c(x)
        self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))

        # 3x3 conv no padding stride 1
        c = nn.Conv2d(3, 8, 3)
        y = c(x)
        self.assertEqual(y.shape, [4, 6, 6, 8])
        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)

        # 3x3 conv padding 1 stride 1
        c = nn.Conv2d(3, 8, 3, padding=1)
        y = c(x)
        self.assertEqual(y.shape, [4, 8, 8, 8])
        self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
        # Border outputs only see the part of the kernel overlapping the input.
        self.assertLess(
            mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )

        # 3x3 conv no padding stride 2
        c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
        y = c(x)
        self.assertEqual(y.shape, [4, 3, 3, 8])
        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)

    def test_sequential(self):
        """Sequential runs layers in order and exposes their parameters."""
        x = mx.ones((10, 2))
        m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
        y = m(x)
        self.assertEqual(y.shape, [10, 1])
        params = m.parameters()
        self.assertIn("layers", params)
        self.assertEqual(len(params["layers"]), 3)
        self.assertIn("weight", params["layers"][0])
        self.assertEqual(len(params["layers"][1]), 0)
        self.assertIn("weight", params["layers"][2])

        # Swapping a module for an equivalent plain function gives the same output.
        m.layers[1] = nn.relu
        y2 = m(x)
        self.assertTrue(mx.array_equal(y, y2))

    def test_module_utilities(self):
        """children(), leaf_modules(), train()/eval() and apply_to_modules()."""
        m = nn.Sequential(
            nn.Sequential(nn.Linear(2, 10), nn.relu),
            nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
            nn.Linear(10, 1),
            mx.sigmoid,
        )

        children = m.children()
        self.assertIsInstance(children, dict)
        self.assertEqual(len(children), 1)
        self.assertIsInstance(children["layers"], list)
        self.assertEqual(len(children["layers"]), 4)
        self.assertEqual(children["layers"][3], {})
        flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
        self.assertEqual(len(flat_children), 3)

        leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
        self.assertEqual(len(leaves), 4)
        self.assertEqual(leaves[0][0], "layers.0.layers.0")
        self.assertEqual(leaves[1][0], "layers.1.layers.0")
        self.assertEqual(leaves[2][0], "layers.1.layers.1")
        self.assertEqual(leaves[3][0], "layers.2")
        self.assertIs(leaves[0][1], m.layers[0].layers[0])
        self.assertIs(leaves[1][1], m.layers[1].layers[0])
        self.assertIs(leaves[2][1], m.layers[1].layers[1])
        self.assertIs(leaves[3][1], m.layers[2])

        m.eval()

        def assert_not_training(_, module):
            self.assertFalse(module.training)

        m.apply_to_modules(assert_not_training)

        m.train()

        def assert_training(_, module):
            self.assertTrue(module.training)

        m.apply_to_modules(assert_training)

    def test_sin_pe(self):
        """Each sinusoidal positional encoding has a unit self-similarity."""
        m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
        x = mx.arange(10)
        y = m(x)

        self.assertEqual(y.shape, [10, 16])
        similarities = y @ y.T
        self.assertLess(
            mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
        )

    def test_io(self):
        """Weights saved with save_weights round-trip through load_weights."""

        def make_model():
            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

        m = make_model()
        # The context manager removes the directory even if save/load raises.
        with tempfile.TemporaryDirectory() as tdir:
            file = os.path.join(tdir, "model.npz")
            m.save_weights(file)
            m_load = make_model()
            m_load.load_weights(file)

        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
        # tree_flatten yields (key, value) pairs; a non-empty tuple is always
        # truthy, so the per-parameter comparison results must be checked.
        self.assertTrue(all(equal for _, equal in tree_flatten(eq_tree)))

    def test_relu(self):
        """relu zeroes negatives and keeps shape and dtype."""
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.relu(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_leaky_relu(self):
        """leaky_relu with the default slope and with a custom slope."""
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.leaky_relu(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.LeakyReLU(negative_slope=0.1)(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_elu(self):
        """elu with the default alpha and with alpha=1.1."""
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.elu(x)
        epsilon = 1e-4
        expected_y = mx.array([1.0, -0.6321, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.ELU(alpha=1.1)(x)
        epsilon = 1e-4
        expected_y = mx.array([1.0, -0.6953, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_relu6(self):
        """relu6 clips to the range [0, 6]."""
        x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
        y = nn.relu6(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
        self.assertEqual(y.shape, [5])
        self.assertEqual(y.dtype, mx.float32)

    def test_softplus(self):
        """softplus matches log(1 + exp(x)) reference values."""
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.softplus(x)
        epsilon = 1e-4
        expected_y = mx.array([1.3133, 0.3133, 0.6931])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_celu(self):
        """celu with the default alpha and with alpha=1.1."""
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.celu(x)
        epsilon = 1e-4
        expected_y = mx.array([1.0, -0.6321, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.CELU(alpha=1.1)(x)
        expected_y = mx.array([1.0, -0.6568, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_log_sigmoid(self):
        """log_sigmoid matches reference values."""
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.log_sigmoid(x)
        epsilon = 1e-4
        expected_y = mx.array([-0.3133, -1.3133, -0.6931])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_prelu(self):
        """PReLU scales negatives by its default slope (0.25 per the expected output)."""
        self.assertEqualArray(
            nn.PReLU()(mx.array([1.0, -1.0, 0.0, 0.5])),
            mx.array([1.0, -0.25, 0.0, 0.5]),
        )

    def test_mish(self):
        """Mish matches reference values."""
        self.assertEqualArray(
            nn.Mish()(mx.array([1.0, -1.0, 0.0, 0.5])),
            mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
        )

    def test_rope(self):
        """RoPE preserves input shape and dtype across its configurations."""
        # assertTrue(a, b) treats b as a failure message and only checks that a
        # is truthy, so the shape/dtype checks below must use assertEqual.
        for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
            rope = nn.RoPE(4, **kwargs)
            shape = (1, 3, 4)
            x = mx.random.uniform(shape=shape)
            y = rope(x)
            self.assertEqual(tuple(y.shape), shape)
            self.assertEqual(y.dtype, mx.float32)

            y = rope(x, offset=3)
            self.assertEqual(tuple(y.shape), shape)

            y = rope(x.astype(mx.float16))
            self.assertEqual(y.dtype, mx.float16)

    def test_alibi(self):
        """ALiBi preserves the attention-score shape and dtype."""
        alibi = nn.ALiBi()
        shape = (1, 8, 20, 20)
        x = mx.random.uniform(shape=shape)
        y = alibi(x)
        self.assertEqual(tuple(y.shape), shape)
        self.assertEqual(y.dtype, mx.float32)

        y = alibi(x.astype(mx.float16))
        self.assertEqual(y.dtype, mx.float16)

    def test_hinge_loss(self):
        """hinge_loss of ones vs zeros has mean 1."""
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
        self.assertEqual(loss, 1.0)

    def test_huber_loss(self):
        """huber_loss of ones vs zeros has mean 0.5."""
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
        self.assertEqual(loss, 0.5)

    def test_log_cosh_loss(self):
        """log_cosh_loss of ones vs zeros has mean log(cosh(1))."""
        inputs = mx.ones((2, 4))
        targets = mx.zeros((2, 4))
        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
        self.assertAlmostEqual(loss.item(), 0.433781, places=6)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the whole suite when this file is executed directly.
    unittest.main()
|