mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Add softmin, hardshrink, hardtanh (#1180)
--------- Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
@@ -55,9 +55,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
m.apply_to_modules(assert_training)
|
||||
|
||||
def test_module_attributes(self):
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.val = None
|
||||
@@ -806,6 +804,15 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softmin(self):
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
y = nn.softmin(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.6652, 0.2447, 0.0900])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softplus(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.softplus(x)
|
||||
@@ -899,6 +906,28 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
out = nn.glu(x)
|
||||
self.assertEqualArray(out, y)
|
||||
|
||||
def test_hard_tanh(self):
|
||||
x = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
|
||||
y = nn.hard_tanh(x)
|
||||
expected_y = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])
|
||||
self.assertTrue(mx.array_equal(y, expected_y))
|
||||
self.assertEqual(y.shape, (5,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_hard_shrink(self):
|
||||
x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
|
||||
y = nn.hard_shrink(x)
|
||||
expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
|
||||
self.assertTrue(mx.array_equal(y, expected_y))
|
||||
self.assertEqual(y.shape, (5,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.hard_shrink(x, lambd=1.0)
|
||||
expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
|
||||
self.assertTrue(mx.array_equal(y, expected_y))
|
||||
self.assertEqual(y.shape, (5,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_rope(self):
|
||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
|
||||
rope = nn.RoPE(4, **kwargs)
|
||||
|
Reference in New Issue
Block a user