mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Added GLU activation function and Gated activation function (#329)
* Added GLU activation function and gated activation function
* Ran pre-commit
* Ran pre-commit
* Removed old sigmoid implementation to match with main
* Removed gated activation from __init__.py
* Removed unused test cases
* Removed unused imports
* Format / docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -789,6 +789,12 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, [5])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_glu(self):
    """Check nn.glu: the last axis is split in half into (a, b) and the
    result is a * sigmoid(b); here sigmoid(3)≈0.952574, 2*sigmoid(4)≈1.96403."""
    inputs = mx.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=mx.float32)
    expected = mx.array([[[0.952574, 1.96403]]], dtype=mx.float32)
    result = nn.glu(inputs)
    self.assertEqualArray(result, expected)
|
||||
|
||||
def test_rope(self):
|
||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
|
||||
rope = nn.RoPE(4, **kwargs)
|
||||
|
Reference in New Issue
Block a user