Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-22 21:16:47 +08:00)
commit 353a4cb1cd
parent 80c4630d26

    black
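Note (not part of the commit message): the diff below is the output of running the black formatter over the loss-function tests. Calls and array literals that exceed black's default 88-character line limit are wrapped onto multiple lines, and a redundant blank line in the import block is dropped; no test behavior changes.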
@@ -6,7 +6,6 @@ import unittest
 
 import mlx.core as mx
 import mlx.nn as nn
-
 import mlx_tests
 import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
@@ -197,12 +196,18 @@ class TestNN(mlx_tests.MLXTestCase):
         delta = 1.0
 
         # Test with reduction 'none'
-        losses_none = nn.losses.huber_loss(predictions, targets, delta, reduction="none")
-        expected_none = mx.array([0.125, 0.125, 0.125, 0.125])  # Example expected values
+        losses_none = nn.losses.huber_loss(
+            predictions, targets, delta, reduction="none"
+        )
+        expected_none = mx.array(
+            [0.125, 0.125, 0.125, 0.125]
+        )  # Example expected values
         self.assertTrue(mx.allclose(losses_none, expected_none))
 
         # Test with reduction 'mean'
-        losses_mean = nn.losses.huber_loss(predictions, targets, delta, reduction="mean")
+        losses_mean = nn.losses.huber_loss(
+            predictions, targets, delta, reduction="mean"
+        )
         expected_mean = mx.mean(expected_none)
         self.assertTrue(mx.allclose(losses_mean, expected_mean))
 
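Note (not part of the commit): the expected values in this hunk are consistent with the standard Huber loss at delta = 1.0, which uses 0.5 * error**2 when |error| <= delta, so an elementwise error of 0.5 gives 0.125. The predictions and targets arrays are defined outside the hunk, so the 0.5 error below is an assumption inferred from the expected output; the sketch is illustrative only.

# Illustrative sketch of the standard Huber loss; the inputs below are
# hypothetical and assumed to differ by 0.5 elementwise.
import numpy as np

def huber(preds, targets, delta=1.0):
    err = np.abs(preds - targets)
    quadratic = 0.5 * err**2              # used where |error| <= delta
    linear = delta * (err - 0.5 * delta)  # used where |error| >  delta
    return np.where(err <= delta, quadratic, linear)

preds = np.array([0.5, 1.5, 2.5, 3.5])    # hypothetical values
targets = np.array([1.0, 2.0, 3.0, 4.0])  # each 0.5 away from preds
print(huber(preds, targets, delta=1.0))   # [0.125 0.125 0.125 0.125]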
@@ -237,17 +242,28 @@ class TestNN(mlx_tests.MLXTestCase):
         gamma = 2.0
 
         # Test with reduction 'none'
-        losses_none = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="none")
-        expected_none = mx.array([[0.0433217, 0.0433217, 0.25751, 0.466273], [0.000263401, 0.147487, 0.0433217, 0.0433217]])
+        losses_none = nn.losses.focal_loss(
+            inputs, targets, alpha, gamma, reduction="none"
+        )
+        expected_none = mx.array(
+            [
+                [0.0433217, 0.0433217, 0.25751, 0.466273],
+                [0.000263401, 0.147487, 0.0433217, 0.0433217],
+            ]
+        )
         self.assertTrue(mx.allclose(losses_none, expected_none))
 
         # Test with reduction 'mean'
-        losses_mean = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="mean")
+        losses_mean = nn.losses.focal_loss(
+            inputs, targets, alpha, gamma, reduction="mean"
+        )
         expected_mean = mx.mean(expected_none)
         self.assertTrue(mx.allclose(losses_mean, expected_mean))
 
         # Test with reduction 'sum'
-        losses_sum = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="sum")
+        losses_sum = nn.losses.focal_loss(
+            inputs, targets, alpha, gamma, reduction="sum"
+        )
         expected_sum = mx.sum(expected_none)
         self.assertTrue(mx.allclose(losses_sum, expected_sum))
 
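Note (not part of the commit): focal loss is commonly defined (Lin et al., 2017) as -alpha_t * (1 - p_t)**gamma * log(p_t), where p_t is the predicted probability of the true class. Whether the PR's nn.losses.focal_loss takes probabilities or logits, and how it applies alpha, is not visible in this diff, so the sketch below only illustrates the textbook binary form and the reduction="none"/"mean"/"sum" relationship exercised by the test.

# Illustrative sketch of the textbook binary focal loss; `probs` and `targets`
# here are hypothetical and unrelated to the test's actual inputs.
import numpy as np

def focal_loss(probs, targets, alpha=0.25, gamma=2.0):
    p_t = np.where(targets == 1, probs, 1.0 - probs)          # prob. of the true class
    alpha_t = np.where(targets == 1, alpha, 1.0 - alpha)      # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)

losses = focal_loss(np.array([0.9, 0.2]), np.array([1, 0]))
print(losses)         # per-element losses, i.e. reduction="none"
print(losses.mean())  # reduction="mean"
print(losses.sum())   # reduction="sum"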
@@ -258,17 +274,23 @@ class TestNN(mlx_tests.MLXTestCase):
         margin = 1.0
 
         # Test with reduction 'none'
-        losses_none = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="none")
+        losses_none = nn.losses.contrastive_loss(
+            embeddings1, embeddings2, targets, margin, reduction="none"
+        )
         expected_none = mx.array([0.2, 0.735425])
         self.assertTrue(mx.allclose(losses_none, expected_none))
 
         # Test with reduction 'mean'
-        losses_mean = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="mean")
+        losses_mean = nn.losses.contrastive_loss(
+            embeddings1, embeddings2, targets, margin, reduction="mean"
+        )
         expected_mean = mx.mean(expected_none)
         self.assertTrue(mx.allclose(losses_mean, expected_mean))
 
         # Test with reduction 'sum'
-        losses_sum = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="sum")
+        losses_sum = nn.losses.contrastive_loss(
+            embeddings1, embeddings2, targets, margin, reduction="sum"
+        )
         expected_sum = mx.sum(expected_none)
         self.assertTrue(mx.allclose(losses_sum, expected_sum))
 
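Note (not part of the commit): contrastive loss is most often written as y * d**2 + (1 - y) * max(margin - d, 0)**2 over a pair distance d and a binary similarity label y (Hadsell et al., 2006), though unsquared variants exist. The embeddings and the exact form used by the PR's nn.losses.contrastive_loss are not shown in this diff; the sketch below is one common form under that assumption.

# Illustrative sketch of a common contrastive-loss form; the embeddings and
# the exact definition are assumptions, not taken from the diff.
import numpy as np

def contrastive_loss(emb1, emb2, targets, margin=1.0):
    d = np.linalg.norm(emb1 - emb2, axis=1)                       # pairwise Euclidean distance
    positive = targets * d**2                                     # similar pairs (y = 1)
    negative = (1 - targets) * np.maximum(margin - d, 0.0) ** 2   # dissimilar pairs (y = 0)
    return positive + negative

emb1 = np.array([[0.0, 1.0], [1.0, 0.0]])   # hypothetical embeddings
emb2 = np.array([[0.2, 1.0], [0.0, 0.0]])
targets = np.array([1, 0])                  # 1 = similar, 0 = dissimilar
print(contrastive_loss(emb1, emb2, targets, margin=1.0))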
@@ -278,17 +300,23 @@ class TestNN(mlx_tests.MLXTestCase):
         targets = mx.array([1, -1])
 
         # Test with reduction 'none'
-        losses_none = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="none")
+        losses_none = nn.losses.cosine_similarity_loss(
+            embeddings1, embeddings2, targets, reduction="none"
+        )
         expected_none = mx.array([0.0146555, 0.961074])
         self.assertTrue(mx.allclose(losses_none, expected_none))
 
         # Test with reduction 'mean'
-        losses_mean = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="mean")
+        losses_mean = nn.losses.cosine_similarity_loss(
+            embeddings1, embeddings2, targets, reduction="mean"
+        )
         expected_mean = mx.mean(expected_none)
         self.assertTrue(mx.allclose(losses_mean, expected_mean))
 
         # Test with reduction 'sum'
-        losses_sum = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="sum")
+        losses_sum = nn.losses.cosine_similarity_loss(
+            embeddings1, embeddings2, targets, reduction="sum"
+        )
         expected_sum = mx.sum(expected_none)
         self.assertTrue(mx.allclose(losses_sum, expected_sum))
 
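Note (not part of the commit): with targets of 1 and -1 this kind of loss typically mirrors a cosine-embedding formulation: 1 - cos(x1, x2) when the target is 1, and a penalty on a positive cosine (for example max(cos(x1, x2), 0)) when the target is -1. The PR's exact formula and the test's embeddings are not visible in this diff, so the sketch below is only a hedged illustration.

# Illustrative sketch of a cosine-embedding-style loss; the exact definition
# and the inputs below are assumptions, not taken from the diff.
import numpy as np

def cosine_similarity_loss(emb1, emb2, targets):
    cos = np.sum(emb1 * emb2, axis=1) / (
        np.linalg.norm(emb1, axis=1) * np.linalg.norm(emb2, axis=1)
    )
    # target 1: penalize low similarity; target -1: penalize high similarity
    return np.where(targets == 1, 1.0 - cos, np.maximum(cos, 0.0))

emb1 = np.array([[1.0, 0.0], [1.0, 1.0]])   # hypothetical embeddings
emb2 = np.array([[0.9, 0.1], [1.0, 0.9]])
targets = np.array([1, -1])
print(cosine_similarity_loss(emb1, emb2, targets))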