This commit is contained in:
Jyun1998 2024-01-02 02:55:56 +09:00
parent 80c4630d26
commit 353a4cb1cd

View File

@ -6,7 +6,6 @@ import unittest
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx_tests import mlx_tests
import numpy as np import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten from mlx.utils import tree_flatten, tree_map, tree_unflatten
@ -197,12 +196,18 @@ class TestNN(mlx_tests.MLXTestCase):
delta = 1.0 delta = 1.0
# Test with reduction 'none' # Test with reduction 'none'
losses_none = nn.losses.huber_loss(predictions, targets, delta, reduction="none") losses_none = nn.losses.huber_loss(
expected_none = mx.array([0.125, 0.125, 0.125, 0.125]) # Example expected values predictions, targets, delta, reduction="none"
)
expected_none = mx.array(
[0.125, 0.125, 0.125, 0.125]
) # Example expected values
self.assertTrue(mx.allclose(losses_none, expected_none)) self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean' # Test with reduction 'mean'
losses_mean = nn.losses.huber_loss(predictions, targets, delta, reduction="mean") losses_mean = nn.losses.huber_loss(
predictions, targets, delta, reduction="mean"
)
expected_mean = mx.mean(expected_none) expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean)) self.assertTrue(mx.allclose(losses_mean, expected_mean))
@ -237,17 +242,28 @@ class TestNN(mlx_tests.MLXTestCase):
gamma = 2.0 gamma = 2.0
# Test with reduction 'none' # Test with reduction 'none'
losses_none = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="none") losses_none = nn.losses.focal_loss(
expected_none = mx.array([[0.0433217, 0.0433217, 0.25751, 0.466273], [0.000263401, 0.147487, 0.0433217, 0.0433217]]) inputs, targets, alpha, gamma, reduction="none"
)
expected_none = mx.array(
[
[0.0433217, 0.0433217, 0.25751, 0.466273],
[0.000263401, 0.147487, 0.0433217, 0.0433217],
]
)
self.assertTrue(mx.allclose(losses_none, expected_none)) self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean' # Test with reduction 'mean'
losses_mean = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="mean") losses_mean = nn.losses.focal_loss(
inputs, targets, alpha, gamma, reduction="mean"
)
expected_mean = mx.mean(expected_none) expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean)) self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum' # Test with reduction 'sum'
losses_sum = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="sum") losses_sum = nn.losses.focal_loss(
inputs, targets, alpha, gamma, reduction="sum"
)
expected_sum = mx.sum(expected_none) expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum)) self.assertTrue(mx.allclose(losses_sum, expected_sum))
@ -258,17 +274,23 @@ class TestNN(mlx_tests.MLXTestCase):
margin = 1.0 margin = 1.0
# Test with reduction 'none' # Test with reduction 'none'
losses_none = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="none") losses_none = nn.losses.contrastive_loss(
embeddings1, embeddings2, targets, margin, reduction="none"
)
expected_none = mx.array([0.2, 0.735425]) expected_none = mx.array([0.2, 0.735425])
self.assertTrue(mx.allclose(losses_none, expected_none)) self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean' # Test with reduction 'mean'
losses_mean = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="mean") losses_mean = nn.losses.contrastive_loss(
embeddings1, embeddings2, targets, margin, reduction="mean"
)
expected_mean = mx.mean(expected_none) expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean)) self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum' # Test with reduction 'sum'
losses_sum = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="sum") losses_sum = nn.losses.contrastive_loss(
embeddings1, embeddings2, targets, margin, reduction="sum"
)
expected_sum = mx.sum(expected_none) expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum)) self.assertTrue(mx.allclose(losses_sum, expected_sum))
@ -278,17 +300,23 @@ class TestNN(mlx_tests.MLXTestCase):
targets = mx.array([1, -1]) targets = mx.array([1, -1])
# Test with reduction 'none' # Test with reduction 'none'
losses_none = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="none") losses_none = nn.losses.cosine_similarity_loss(
embeddings1, embeddings2, targets, reduction="none"
)
expected_none = mx.array([0.0146555, 0.961074]) expected_none = mx.array([0.0146555, 0.961074])
self.assertTrue(mx.allclose(losses_none, expected_none)) self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean' # Test with reduction 'mean'
losses_mean = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="mean") losses_mean = nn.losses.cosine_similarity_loss(
embeddings1, embeddings2, targets, reduction="mean"
)
expected_mean = mx.mean(expected_none) expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean)) self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum' # Test with reduction 'sum'
losses_sum = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="sum") losses_sum = nn.losses.cosine_similarity_loss(
embeddings1, embeddings2, targets, reduction="sum"
)
expected_sum = mx.sum(expected_none) expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum)) self.assertTrue(mx.allclose(losses_sum, expected_sum))