Add softmin, hardshrink, hardtanh (#1180)

---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
Nikhil Mehta
2024-06-04 15:48:18 -07:00
committed by GitHub
parent 83b11bc58d
commit 0b7d71fd2f
14 changed files with 110 additions and 20 deletions

View File

@@ -55,9 +55,7 @@ class TestBase(mlx_tests.MLXTestCase):
m.apply_to_modules(assert_training)
def test_module_attributes(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.val = None
@@ -806,6 +804,15 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_softmin(self):
    """Check nn.softmin against precomputed reference values.

    softmin(x) is softmax(-x); for [1, 2, 3] the largest weight goes
    to the smallest input.
    """
    inputs = mx.array([1.0, 2.0, 3.0])
    out = nn.softmin(inputs)
    # Reference values rounded to 4 decimals, hence the loose tolerance.
    reference = mx.array([0.6652, 0.2447, 0.0900])
    tolerance = 1e-4
    self.assertTrue(mx.all(mx.abs(out - reference) < tolerance))
    self.assertEqual(out.shape, (3,))
    self.assertEqual(out.dtype, mx.float32)
def test_softplus(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softplus(x)
@@ -899,6 +906,28 @@ class TestLayers(mlx_tests.MLXTestCase):
out = nn.glu(x)
self.assertEqualArray(out, y)
def test_hard_tanh(self):
    """Check that nn.hard_tanh clamps inputs into [-1, 1].

    Values inside the range pass through unchanged; values outside
    saturate at the nearest bound.
    """
    inputs = mx.array([1.0, -2.0, 0.0, 0.5, 2.0])
    out = nn.hard_tanh(inputs)
    clamped = mx.array([1.0, -1.0, 0.0, 0.5, 1.0])
    self.assertTrue(mx.array_equal(out, clamped))
    self.assertEqual(out.shape, (5,))
    self.assertEqual(out.dtype, mx.float32)
def test_hard_shrink(self):
    """Check nn.hard_shrink with the default and a custom lambd.

    hard_shrink zeroes out small-magnitude entries and leaves the
    rest untouched; the threshold is controlled by ``lambd``.
    """
    inputs = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])

    # Default lambd (0.5): every entry here survives the shrink.
    out = nn.hard_shrink(inputs)
    self.assertTrue(mx.array_equal(out, mx.array([1.0, -0.5, 0.0, 0.5, -1.5])))
    self.assertEqual(out.shape, (5,))
    self.assertEqual(out.dtype, mx.float32)

    # lambd=1.0: mid-range entries are zeroed, large magnitudes kept.
    out = nn.hard_shrink(inputs, lambd=1.0)
    self.assertTrue(mx.array_equal(out, mx.array([1.0, 0.0, 0.0, 0.0, -1.5])))
    self.assertEqual(out.shape, (5,))
    self.assertEqual(out.dtype, mx.float32)
def test_rope(self):
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
rope = nn.RoPE(4, **kwargs)